@@ -308,23 +281,22 @@ make_image_grid([original_image, canny_image, image], rows=1, cols=3)
## ControlNet with Stable Diffusion XL
-There aren't too many ControlNet models compatible with Stable Diffusion XL (SDXL) at the moment, but we've trained two full-sized ControlNet models for SDXL conditioned on canny edge detection and depth maps. We're also experimenting with creating smaller versions of these SDXL-compatible ControlNet models so it is easier to run on resource-constrained hardware. You can find these checkpoints on the [🤗 Diffusers Hub organization](https://huggingface.co/diffusers)!
+There aren't too many ControlNet models compatible with Stable Diffusion XL (SDXL) at the moment, but we've trained two full-sized ControlNet models for SDXL conditioned on canny edge detection and depth maps. We're also experimenting with creating smaller versions of these SDXL-compatible ControlNet models so they are easier to run on resource-constrained hardware. You can find these checkpoints on the 🤗 [Diffusers](https://huggingface.co/diffusers) Hub organization!
Let's use an SDXL ControlNet conditioned on canny images to generate an image. Start by loading an image and preparing the canny image:
```py
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
from PIL import Image
import cv2
import numpy as np
-import torch
-original_image = load_image(
+image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
-image = np.array(original_image)
+image = np.array(image)
low_threshold = 100
high_threshold = 200
@@ -333,7 +305,7 @@ image = cv2.Canny(image, low_threshold, high_threshold)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
-make_image_grid([original_image, canny_image], rows=1, cols=2)
+canny_image
```
@@ -378,13 +350,13 @@ The [`controlnet_conditioning_scale`](https://huggingface.co/docs/diffusers/main
prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = 'low quality, bad quality, sketches'
-image = pipe(
- prompt,
- negative_prompt=negative_prompt,
- image=canny_image,
+image = pipe(
+    prompt,
+    negative_prompt=negative_prompt,
+    image=canny_image,
controlnet_conditioning_scale=0.5,
).images[0]
-make_image_grid([original_image, canny_image, image], rows=1, cols=3)
+image
```
@@ -395,16 +367,17 @@ You can use [`StableDiffusionXLControlNetPipeline`] in guess mode as well by set
```py
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
import numpy as np
import torch
+
import cv2
from PIL import Image
prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
negative_prompt = "low quality, bad quality, sketches"
-original_image = load_image(
+image = load_image(
"https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
@@ -417,16 +390,15 @@ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
)
pipe.enable_model_cpu_offload()
-image = np.array(original_image)
+image = np.array(image)
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
image = pipe(
- prompt, negative_prompt=negative_prompt, controlnet_conditioning_scale=0.5, image=canny_image, guess_mode=True,
+    prompt, negative_prompt=negative_prompt, controlnet_conditioning_scale=0.5, image=canny_image, guess_mode=True,
).images[0]
-make_image_grid([original_image, canny_image, image], rows=1, cols=3)
```
### MultiControlNet
@@ -447,30 +419,29 @@ In this example, you'll combine a canny image and a human pose estimation image
Prepare the canny image conditioning:
```py
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
from PIL import Image
-import numpy as np
+import numpy as np
import cv2
-original_image = load_image(
+canny_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/landscape.png"
)
-image = np.array(original_image)
+canny_image = np.array(canny_image)
low_threshold = 100
high_threshold = 200
-image = cv2.Canny(image, low_threshold, high_threshold)
+canny_image = cv2.Canny(canny_image, low_threshold, high_threshold)
-# zero out middle columns of image where pose will be overlaid
-zero_start = image.shape[1] // 4
-zero_end = zero_start + image.shape[1] // 2
-image[:, zero_start:zero_end] = 0
+# zero out middle columns of image where pose will be overlaid
+zero_start = canny_image.shape[1] // 4
+zero_end = zero_start + canny_image.shape[1] // 2
+canny_image[:, zero_start:zero_end] = 0
-image = image[:, :, None]
-image = np.concatenate([image, image, image], axis=2)
-canny_image = Image.fromarray(image)
-make_image_grid([original_image, canny_image], rows=1, cols=2)
+canny_image = canny_image[:, :, None]
+canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
+canny_image = Image.fromarray(canny_image).resize((1024, 1024))
```
@@ -484,24 +455,18 @@ make_image_grid([original_image, canny_image], rows=1, cols=2)
-For human pose estimation, install [controlnet_aux](https://github.com/patrickvonplaten/controlnet_aux):
-
-```py
-# uncomment to install the necessary library in Colab
-#!pip install -q controlnet-aux
-```
-
Prepare the human pose estimation conditioning:
```py
from controlnet_aux import OpenposeDetector
+from diffusers.utils import load_image
openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
-original_image = load_image(
+
+openpose_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/person.png"
)
-openpose_image = openpose(original_image)
-make_image_grid([original_image, openpose_image], rows=1, cols=2)
+openpose_image = openpose(openpose_image).resize((1024, 1024))
```
@@ -523,7 +488,7 @@ import torch
controlnets = [
ControlNetModel.from_pretrained(
- "thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float16
+ "thibaud/controlnet-openpose-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
),
ControlNetModel.from_pretrained(
"diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16, use_safetensors=True
@@ -546,7 +511,7 @@ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
generator = torch.manual_seed(1)
-images = [openpose_image.resize((1024, 1024)), canny_image.resize((1024, 1024))]
+images = [openpose_image, canny_image]
images = pipe(
prompt,
@@ -556,11 +521,9 @@ images = pipe(
negative_prompt=negative_prompt,
num_images_per_prompt=3,
controlnet_conditioning_scale=[1.0, 0.8],
-).images
-make_image_grid([original_image, canny_image, openpose_image,
- images[0].resize((512, 512)), images[1].resize((512, 512)), images[2].resize((512, 512))], rows=2, cols=3)
+).images[0]
```

-
+
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/custom_pipeline_examples.md b/docs/source/en/using-diffusers/custom_pipeline_examples.md
index e0d3182f3e8a..2f47d1b26c6c 100644
--- a/docs/source/en/using-diffusers/custom_pipeline_examples.md
+++ b/docs/source/en/using-diffusers/custom_pipeline_examples.md
@@ -14,106 +14,273 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-
+> **For more information about community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).**
-For more context about the design choices behind community pipelines, please have a look at [this issue](https://github.com/huggingface/diffusers/issues/841).
+**Community** examples consist of both inference and training examples that have been added by the community.
+Please have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste ready code example that you can try out.
+If a community pipeline doesn't work as expected, please open an issue and ping the author on it.
-
-
-Community pipelines allow you to get creative and build your own unique pipelines to share with the community. You can find all community pipelines in the [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) folder along with inference and training examples for how to use them. This guide showcases some of the community pipelines and hopefully it'll inspire you to create your own (feel free to open a PR with your own pipeline and we will merge it!).
-
-To load a community pipeline, use the `custom_pipeline` argument in [`DiffusionPipeline`] to specify one of the files in [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community):
+| Example | Description | Code Example | Colab | Author |
+|:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:|
+| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
+| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) |
+| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
+| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without a token length limit, with support for parsing weightings in the prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
+| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) |
+To load a custom pipeline, pass the `custom_pipeline` argument to `DiffusionPipeline` with the name of one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
```py
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", custom_pipeline="filename_in_the_community_folder", use_safetensors=True
)
```
-If a community pipeline doesn't work as expected, please open a GitHub issue and mention the author.
+## Example usages
-You can learn more about community pipelines in the how to [load community pipelines](custom_pipeline_overview) and how to [contribute a community pipeline](contribute_pipeline) guides.
+### CLIP Guided Stable Diffusion
-## Multilingual Stable Diffusion
+CLIP guided Stable Diffusion can help generate more realistic images
+by guiding Stable Diffusion at every denoising step with an additional CLIP model.
-The multilingual Stable Diffusion pipeline uses a pretrained [XLM-RoBERTa](https://huggingface.co/papluca/xlm-roberta-base-language-detection) to identify a language and the [mBART-large-50](https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt) model to handle the translation. This allows you to generate images from text in 20 languages.
+The following code requires roughly 12GB of GPU RAM.
-```py
+```python
+from diffusers import DiffusionPipeline
+from transformers import CLIPImageProcessor, CLIPModel
import torch
+
+
+feature_extractor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
+clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16)
+
+
+guided_pipeline = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="clip_guided_stable_diffusion",
+ clip_model=clip_model,
+ feature_extractor=feature_extractor,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+)
+guided_pipeline.enable_attention_slicing()
+guided_pipeline = guided_pipeline.to("cuda")
+
+prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"
+
+generator = torch.Generator(device="cuda").manual_seed(0)
+images = []
+for i in range(4):
+ image = guided_pipeline(
+ prompt,
+ num_inference_steps=50,
+ guidance_scale=7.5,
+ clip_guidance_scale=100,
+ num_cutouts=4,
+ use_cutouts=False,
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+# save images locally
+for i, img in enumerate(images):
+ img.save(f"./clip_guided_sd/image_{i}.png")
+```
+
+The `images` list contains PIL images that can be saved locally or displayed directly in a Google Colab notebook.
+Generated images tend to be of higher quality than those generated natively with Stable Diffusion.
+
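To inspect all four results at once, here is a minimal sketch that tiles them into a 2x2 grid with PIL; it only assumes the `images` list produced by the loop above:

```python
from PIL import Image

# Tile the four generated images into a 2x2 grid for quick inspection,
# e.g. in a notebook cell. Assumes `images` holds the PIL images from above.
cols, rows = 2, 2
w, h = images[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(images):
    grid.paste(img, box=((i % cols) * w, (i // cols) * h))
grid.save("clip_guided_grid.png")
```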
+
+### One Step Unet
+
+The dummy "one-step-unet" can be run as follows:
+
+```python
from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
-from transformers import (
- pipeline,
- MBart50TokenizerFast,
- MBartForConditionalGeneration,
+
+pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")
+pipe()
+```
+
+**Note**: This community pipeline is not useful as a feature, but rather just serves as an example of how community pipelines can be added (see https://github.com/huggingface/diffusers/issues/841).
+
+### Stable Diffusion Interpolation
+
+The following code can be run on a GPU with at least 8GB of VRAM and should take approximately 5 minutes.
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ torch_dtype=torch.float16,
+ safety_checker=None, # Very important for videos...lots of false positives while interpolating
+ custom_pipeline="interpolate_stable_diffusion",
+ use_safetensors=True,
+).to("cuda")
+pipe.enable_attention_slicing()
+
+frame_filepaths = pipe.walk(
+ prompts=["a dog", "a cat", "a horse"],
+ seeds=[42, 1337, 1234],
+ num_interpolation_steps=16,
+ output_dir="./dreams",
+ batch_size=4,
+ height=512,
+ width=512,
+ guidance_scale=8.5,
+ num_inference_steps=50,
)
+```
-device = "cuda" if torch.cuda.is_available() else "cpu"
-device_dict = {"cuda": 0, "cpu": -1}
+The `walk(...)` function returns the frames it saves under the folder defined by `output_dir`. You can use these frames to create a video of the Stable Diffusion interpolation, for example as sketched below.
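For example, here is a minimal sketch of stitching those saved frames into an mp4. It assumes the frames were written as PNG files under `output_dir` and that `imageio` together with the `imageio-ffmpeg` plugin is installed; adjust the glob pattern if the filenames differ:

```python
import glob

import imageio
import numpy as np
from PIL import Image

# Collect the interpolation frames saved by walk() and stitch them into a video.
frame_paths = sorted(glob.glob("./dreams/**/*.png", recursive=True))
frames = [np.array(Image.open(path).convert("RGB")) for path in frame_paths]
imageio.mimsave("./dreams/interpolation.mp4", frames, fps=8)
```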
-# add language detection pipeline
-language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
-language_detection_pipeline = pipeline("text-classification",
- model=language_detection_model_ckpt,
- device=device_dict[device])
+> **Please have a look at https://github.com/nateraw/stable-diffusion-videos for more detailed information on how to create videos using Stable Diffusion as well as more feature-complete functionality.**
-# add model for language translation
-translation_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
-translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
+### Stable Diffusion Mega
-diffuser_pipeline = DiffusionPipeline.from_pretrained(
+The Stable Diffusion Mega Pipeline lets you use the main use cases of the Stable Diffusion pipeline in a single class.
+
+```python
+#!/usr/bin/env python3
+from diffusers import DiffusionPipeline
+import PIL
+import requests
+from io import BytesIO
+import torch
+
+
+def download_image(url):
+ response = requests.get(url)
+ return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
- custom_pipeline="multilingual_stable_diffusion",
- detection_pipeline=language_detection_pipeline,
- translation_model=translation_model,
- translation_tokenizer=translation_tokenizer,
+ custom_pipeline="stable_diffusion_mega",
torch_dtype=torch.float16,
+ use_safetensors=True,
)
+pipe.to("cuda")
+pipe.enable_attention_slicing()
-diffuser_pipeline.enable_attention_slicing()
-diffuser_pipeline = diffuser_pipeline.to(device)
-prompt = ["a photograph of an astronaut riding a horse",
- "Una casa en la playa",
- "Ein Hund, der Orange isst",
- "Un restaurant parisien"]
+### Text-to-Image
+
+images = pipe.text2img("An astronaut riding a horse").images
-images = diffuser_pipeline(prompt).images
-make_image_grid(images, rows=2, cols=2)
+### Image-to-Image
+
+init_image = download_image(
+ "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+)
+
+prompt = "A fantasy landscape, trending on artstation"
+
+images = pipe.img2img(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+
+### Inpainting
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+init_image = download_image(img_url).resize((512, 512))
+mask_image = download_image(mask_url).resize((512, 512))
+
+prompt = "a cat sitting on a bench"
+images = pipe.inpaint(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.75).images
```
-
-

-
+As shown above, this single pipeline can run "text-to-image", "image-to-image", and "inpainting" in one class.
-## MagicMix
+### Long Prompt Weighting Stable Diffusion
-[MagicMix](https://huggingface.co/papers/2210.16056) is a pipeline that can mix an image and text prompt to generate a new image that preserves the image structure. The `mix_factor` determines how much influence the prompt has on the layout generation, `kmin` controls the number of steps during the content generation process, and `kmax` determines how much information is kept in the layout of the original image.
+This pipeline lets you input prompts without the 77-token length limit. You can increase a word's weighting by using "()" or decrease it by using "[]".
+The pipeline also lets you use the main use cases of the Stable Diffusion pipeline in a single class.
-```py
-from diffusers import DiffusionPipeline, DDIMScheduler
-from diffusers.utils import load_image, make_image_grid
+#### pytorch
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "hakurei/waifu-diffusion", custom_pipeline="lpw_stable_diffusion", torch_dtype=torch.float16, use_safetensors=True
+)
+pipe = pipe.to("cuda")
+
+prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
+neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
+
+pipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]
+```
+
+#### onnxruntime
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipe = DiffusionPipeline.from_pretrained(
+ "CompVis/stable-diffusion-v1-4",
+ custom_pipeline="lpw_stable_diffusion_onnx",
+ revision="onnx",
+ provider="CUDAExecutionProvider",
+ use_safetensors=True,
+)
-pipeline = DiffusionPipeline.from_pretrained(
+prompt = "a photo of an astronaut riding a horse on mars, best quality"
+neg_prompt = "lowres, bad anatomy, error body, error hair, error arm, error hands, bad hands, error fingers, bad fingers, missing fingers, error legs, bad legs, multiple legs, missing legs, error lighting, error shadow, error reflection, text, error, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
+
+pipe.text2img(prompt, negative_prompt=neg_prompt, width=512, height=512, max_embeddings_multiples=3).images[0]
+```
+
+If you see the warning `Token indices sequence length is longer than the specified maximum sequence length for this model ( *** > 77 ) . Running this sequence through the model will result in indexing errors`, do not worry; it is normal.
+
+### Speech to Image
+
+The following code can generate an image from an audio sample using pre-trained OpenAI whisper-small and Stable Diffusion.
+
+```python
+import torch
+
+import matplotlib.pyplot as plt
+from datasets import load_dataset
+from diffusers import DiffusionPipeline
+from transformers import (
+ WhisperForConditionalGeneration,
+ WhisperProcessor,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+audio_sample = ds[3]
+
+text = audio_sample["text"].lower()
+speech_data = audio_sample["audio"]["array"]
+
+model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
+processor = WhisperProcessor.from_pretrained("openai/whisper-small")
+
+diffuser_pipeline = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
- custom_pipeline="magic_mix",
- scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
-).to('cuda')
+ custom_pipeline="speech_to_image_diffusion",
+ speech_model=model,
+ speech_processor=processor,
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+)
+
+diffuser_pipeline.enable_attention_slicing()
+diffuser_pipeline = diffuser_pipeline.to(device)
-img = load_image("https://user-images.githubusercontent.com/59410571/209578593-141467c7-d831-4792-8b9a-b17dc5e47816.jpg")
-mix_img = pipeline(img, prompt="bed", kmin=0.3, kmax=0.5, mix_factor=0.5)
-make_image_grid([img, mix_img], rows=1, cols=2)
+output = diffuser_pipeline(speech_data)
+plt.imshow(output.images[0])
```
+This example produces the following image:
-
-
-

-
original image
-
-
-

-
image and text prompt mix
-
-
+
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/custom_pipeline_overview.md b/docs/source/en/using-diffusers/custom_pipeline_overview.md
index 0f842c1b5b50..ddab47cc6adf 100644
--- a/docs/source/en/using-diffusers/custom_pipeline_overview.md
+++ b/docs/source/en/using-diffusers/custom_pipeline_overview.md
@@ -10,12 +10,10 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Load community pipelines and components
+# Load community pipelines
[[open-in-colab]]
-## Community pipelines
-
Community pipelines are any [`DiffusionPipeline`] class that are different from the original implementation as specified in their paper (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://arxiv.org/abs/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline.
There are many cool community pipelines like [Speech to Image](https://github.com/huggingface/diffusers/tree/main/examples/community#speech-to-image) or [Composable Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#composable-stable-diffusion), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community).
@@ -56,134 +54,4 @@ pipeline = DiffusionPipeline.from_pretrained(
)
```
-For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them and if you're interested in adding a community pipeline check out the [How to contribute a community pipeline](contribute_pipeline) guide!
-
-## Community components
-
-Community components allow users to build pipelines that may have customized components that are not a part of Diffusers. If your pipeline has custom components that Diffusers doesn't already support, you need to provide their implementations as Python modules. These customized components could be a VAE, UNet, and scheduler. In most cases, the text encoder is imported from the Transformers library. The pipeline code itself can also be customized.
-
-This section shows how users should use community components to build a community pipeline.
-
-You'll use the [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) pipeline checkpoint as an example. So, let's start loading the components:
-
-1. Import and load the text encoder from Transformers:
-
-```python
-from transformers import T5Tokenizer, T5EncoderModel
-
-pipe_id = "showlab/show-1-base"
-tokenizer = T5Tokenizer.from_pretrained(pipe_id, subfolder="tokenizer")
-text_encoder = T5EncoderModel.from_pretrained(pipe_id, subfolder="text_encoder")
-```
-
-2. Load a scheduler:
-
-```python
-from diffusers import DPMSolverMultistepScheduler
-
-scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="scheduler")
-```
-
-3. Load an image processor:
-
-```python
-from transformers import CLIPFeatureExtractor
-
-feature_extractor = CLIPFeatureExtractor.from_pretrained(pipe_id, subfolder="feature_extractor")
-```
-
-
-
-In steps 4 and 5, the custom [UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py) and [pipeline](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) implementation must match the format shown in their files for this example to work.
-
-
-
-4. Now you'll load a [custom UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py), which in this example, has already been implemented in the `showone_unet_3d_condition.py` [script](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) for your convenience. You'll notice the `UNet3DConditionModel` class name is changed to `ShowOneUNet3DConditionModel` because [`UNet3DConditionModel`] already exists in Diffusers. Any components needed for the `ShowOneUNet3DConditionModel` class should be placed in the `showone_unet_3d_condition.py` script.
-
-Once this is done, you can initialize the UNet:
-
-```python
-from showone_unet_3d_condition import ShowOneUNet3DConditionModel
-
-unet = ShowOneUNet3DConditionModel.from_pretrained(pipe_id, subfolder="unet")
-```
-
-5. Finally, you'll load the custom pipeline code. For this example, it has already been created for you in the `pipeline_t2v_base_pixel.py` [script](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py). This script contains a custom `TextToVideoIFPipeline` class for generating videos from text. Just like the custom UNet, any code needed for the custom pipeline to work should go in the `pipeline_t2v_base_pixel.py` script.
-
-Once everything is in place, you can initialize the `TextToVideoIFPipeline` with the `ShowOneUNet3DConditionModel`:
-
-```python
-from pipeline_t2v_base_pixel import TextToVideoIFPipeline
-import torch
-
-pipeline = TextToVideoIFPipeline(
- unet=unet,
- text_encoder=text_encoder,
- tokenizer=tokenizer,
- scheduler=scheduler,
- feature_extractor=feature_extractor
-)
-pipeline = pipeline.to(device="cuda")
-pipeline.torch_dtype = torch.float16
-```
-
-Push the pipeline to the Hub to share with the community!
-
-```python
-pipeline.push_to_hub("custom-t2v-pipeline")
-```
-
-After the pipeline is successfully pushed, you need a couple of changes:
-
-1. Change the `_class_name` attribute in [`model_index.json`](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2) to `"pipeline_t2v_base_pixel"` and `"TextToVideoIFPipeline"`.
-2. Upload `showone_unet_3d_condition.py` to the `unet` [directory](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py).
-3. Upload `pipeline_t2v_base_pixel.py` to the pipeline base [directory](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py).
-
-To run inference, simply add the `trust_remote_code` argument while initializing the pipeline to handle all the "magic" behind the scenes.
-
-```python
-from diffusers import DiffusionPipeline
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(
- "
/", trust_remote_code=True, torch_dtype=torch.float16
-).to("cuda")
-
-prompt = "hello"
-
-# Text embeds
-prompt_embeds, negative_embeds = pipeline.encode_prompt(prompt)
-
-# Keyframes generation (8x64x40, 2fps)
-video_frames = pipeline(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_embeds,
- num_frames=8,
- height=40,
- width=64,
- num_inference_steps=2,
- guidance_scale=9.0,
- output_type="pt"
-).frames
-```
-
-As an additional reference example, you can refer to the repository structure of [stabilityai/japanese-stable-diffusion-xl](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/), that makes use of the `trust_remote_code` feature:
-
-```python
-
-from diffusers import DiffusionPipeline
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/japanese-stable-diffusion-xl", trust_remote_code=True
-)
-pipeline.to("cuda")
-
-# if using torch < 2.0
-# pipeline.enable_xformers_memory_efficient_attention()
-
-prompt = "柴犬、カラフルアート"
-
-image = pipeline(prompt=prompt).images[0]
-
-```
\ No newline at end of file
+For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them and if you're interested in adding a community pipeline check out the [How to contribute a community pipeline](contribute_pipeline) guide!
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/depth2img.md b/docs/source/en/using-diffusers/depth2img.md
index 84c613b0dade..0a6df2258235 100644
--- a/docs/source/en/using-diffusers/depth2img.md
+++ b/docs/source/en/using-diffusers/depth2img.md
@@ -20,10 +20,12 @@ Start by creating an instance of the [`StableDiffusionDepth2ImgPipeline`]:
```python
import torch
+import requests
+from PIL import Image
+
from diffusers import StableDiffusionDepth2ImgPipeline
-from diffusers.utils import load_image, make_image_grid
-pipeline = StableDiffusionDepth2ImgPipeline.from_pretrained(
+pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth",
torch_dtype=torch.float16,
use_safetensors=True,
@@ -34,13 +36,22 @@ Now pass your prompt to the pipeline. You can also pass a `negative_prompt` to p
```python
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-init_image = load_image(url)
+init_image = Image.open(requests.get(url, stream=True).raw)
prompt = "two tigers"
-negative_prompt = "bad, deformed, ugly, bad anatomy"
-image = pipeline(prompt=prompt, image=init_image, negative_prompt=negative_prompt, strength=0.7).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+n_prompt = "bad, deformed, ugly, bad anatomy"
+image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
+image
```
| Input | Output |
|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|
|
|
|
+
+Play around with the Spaces below and see if you notice a difference between generated images with and without a depth map!
+
+
diff --git a/docs/source/en/using-diffusers/diffedit.md b/docs/source/en/using-diffusers/diffedit.md
index 1c3793177ce1..4c32eb4c482b 100644
--- a/docs/source/en/using-diffusers/diffedit.md
+++ b/docs/source/en/using-diffusers/diffedit.md
@@ -1,15 +1,3 @@
-
-
# DiffEdit
[[open-in-colab]]
@@ -26,7 +14,7 @@ Before you begin, make sure you have the following libraries installed:
```py
# uncomment to install the necessary libraries in Colab
-#!pip install -q diffusers transformers accelerate
+#!pip install diffusers transformers accelerate safetensors
```
The [`StableDiffusionDiffEditPipeline`] requires an image mask and a set of partially inverted latents. The image mask is generated from the [`~StableDiffusionDiffEditPipeline.generate_mask`] function, and includes two parameters, `source_prompt` and `target_prompt`. These parameters determine what to edit in the image. For example, if you want to change a bowl of *fruits* to a bowl of *pears*, then:
@@ -59,18 +47,15 @@ pipeline.enable_vae_slicing()
Load the image to edit:
```py
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
-raw_image = load_image(img_url).resize((768, 768))
-raw_image
+raw_image = load_image(img_url).convert("RGB").resize((768, 768))
```
Use the [`~StableDiffusionDiffEditPipeline.generate_mask`] function to generate the image mask. You'll need to pass it the `source_prompt` and `target_prompt` to specify what to edit in the image:
```py
-from PIL import Image
-
source_prompt = "a bowl of fruits"
target_prompt = "a basket of pears"
mask_image = pipeline.generate_mask(
@@ -78,7 +63,6 @@ mask_image = pipeline.generate_mask(
source_prompt=source_prompt,
target_prompt=target_prompt,
)
-Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768))
```
Next, create the inverted latents and pass it a caption describing the image:
@@ -90,14 +74,13 @@ inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image).latents
Finally, pass the image mask and inverted latents to the pipeline. The `target_prompt` becomes the `prompt` now, and the `source_prompt` is used as the `negative_prompt`:
```py
-output_image = pipeline(
+image = pipeline(
prompt=target_prompt,
mask_image=mask_image,
image_latents=inv_latents,
negative_prompt=source_prompt,
).images[0]
-mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768))
-make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)
+image.save("edited_image.png")
```
@@ -121,8 +104,8 @@ Load the Flan-T5 model and tokenizer from the 🤗 Transformers library:
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
-tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
-model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto", torch_dtype=torch.float16)
+tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
+model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
```
Provide some initial text to prompt the model to generate the source and target prompts.
@@ -141,7 +124,7 @@ target_text = f"Provide a caption for images containing a {target_concept}. "
Next, create a utility function to generate the prompts:
```py
-@torch.no_grad()
+@torch.no_grad
def generate_prompts(input_prompt):
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")
@@ -165,12 +148,12 @@ Check out the [generation strategy](https://huggingface.co/docs/transformers/mai
Load the text encoder model used by the [`StableDiffusionDiffEditPipeline`] to encode the text. You'll use the text encoder to compute the text embeddings:
```py
-import torch
-from diffusers import StableDiffusionDiffEditPipeline
+import torch
+from diffusers import StableDiffusionDiffEditPipeline
pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()
@@ -198,39 +181,33 @@ Finally, pass the embeddings to the [`~StableDiffusionDiffEditPipeline.generate_
```diff
from diffusers import DDIMInverseScheduler, DDIMScheduler
- from diffusers.utils import load_image, make_image_grid
- from PIL import Image
+ from diffusers.utils import load_image
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
- raw_image = load_image(img_url).resize((768, 768))
+ raw_image = load_image(img_url).convert("RGB").resize((768, 768))
+
mask_image = pipeline.generate_mask(
image=raw_image,
-- source_prompt=source_prompt,
-- target_prompt=target_prompt,
+ source_prompt_embeds=source_embeds,
+ target_prompt_embeds=target_embeds,
)
inv_latents = pipeline.invert(
-- prompt=source_prompt,
+ prompt_embeds=source_embeds,
image=raw_image,
).latents
- output_image = pipeline(
+ images = pipeline(
mask_image=mask_image,
image_latents=inv_latents,
-- prompt=target_prompt,
-- negative_prompt=source_prompt,
+ prompt_embeds=target_embeds,
+ negative_prompt_embeds=source_embeds,
- ).images[0]
- mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L")
- make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)
+ ).images
+ images[0].save("edited_image.png")
```
## Generate a caption for inversion
@@ -271,7 +248,7 @@ Load an input image and generate a caption for it using the `generate_caption` f
from diffusers.utils import load_image
img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
-raw_image = load_image(img_url).resize((768, 768))
+raw_image = load_image(img_url).convert("RGB").resize((768, 768))
caption = generate_caption(raw_image, model, processor)
```
@@ -282,4 +259,4 @@ caption = generate_caption(raw_image, model, processor)
-Now you can drop the caption into the [`~StableDiffusionDiffEditPipeline.invert`] function to generate the partially inverted latents!
+Now you can drop the caption into the [`~StableDiffusionDiffEditPipeline.invert`] function to generate the partially inverted latents!
\ No newline at end of file
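For instance, that final step might look like the following minimal sketch, reusing the `pipeline`, `raw_image`, and `caption` objects created in the snippets above:

```py
# Use the auto-generated caption as the prompt for the partial inversion.
inv_latents = pipeline.invert(prompt=caption, image=raw_image).latents
```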
diff --git a/docs/source/en/using-diffusers/distilled_sd.md b/docs/source/en/using-diffusers/distilled_sd.md
index 2dd96d98861d..7653300b92ab 100644
--- a/docs/source/en/using-diffusers/distilled_sd.md
+++ b/docs/source/en/using-diffusers/distilled_sd.md
@@ -1,15 +1,3 @@
-
-
# Distilled Stable Diffusion inference
[[open-in-colab]]
diff --git a/docs/source/en/using-diffusers/freeu.md b/docs/source/en/using-diffusers/freeu.md
index 6e8f5773cd75..6c23ec754382 100644
--- a/docs/source/en/using-diffusers/freeu.md
+++ b/docs/source/en/using-diffusers/freeu.md
@@ -1,37 +1,25 @@
-
-
# Improve generation quality with FreeU
[[open-in-colab]]
-The UNet is responsible for denoising during the reverse diffusion process, and there are two distinct features in its architecture:
+The UNet is responsible for denoising during the reverse diffusion process, and there are two distinct features in its architecture:
1. Backbone features primarily contribute to the denoising process
2. Skip features mainly introduce high-frequency features into the decoder module and can make the network overlook the semantics in the backbone features
-However, the skip connection can sometimes introduce unnatural image details. [FreeU](https://hf.co/papers/2309.11497) is a technique for improving image quality by rebalancing the contributions from the UNet’s skip connections and backbone feature maps.
+However, the skip connection can sometimes introduce unnatural image details. [FreeU](https://hf.co/papers/2309.11497) is a technique for improving image quality by rebalancing the contributions from the UNet’s skip connections and backbone feature maps.
FreeU is applied during inference and it does not require any additional training. The technique works for different tasks such as text-to-image, image-to-image, and text-to-video.
-In this guide, you will apply FreeU to the [`StableDiffusionPipeline`], [`StableDiffusionXLPipeline`], and [`TextToVideoSDPipeline`]. You need to install Diffusers from source to run the examples below.
+In this guide, you will apply FreeU to the [`StableDiffusionPipeline`], [`StableDiffusionXLPipeline`], and [`TextToVideoSDPipeline`].
## StableDiffusionPipeline
-Load the pipeline:
+Load the pipeline:
```py
from diffusers import DiffusionPipeline
-import torch
+import torch
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
@@ -58,7 +46,6 @@ And then run inference:
prompt = "A squirrel eating a burger"
seed = 2023
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
-image
```
The figure below compares non-FreeU and FreeU results respectively for the same hyperparameters used above (`prompt` and `seed`):
@@ -70,7 +57,7 @@ Let's see how Stable Diffusion 2 results are impacted:
```py
from diffusers import DiffusionPipeline
-import torch
+import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, safety_checker=None
@@ -81,9 +68,9 @@ seed = 2023
pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
-image
```
+

## Stable Diffusion XL
@@ -92,7 +79,7 @@ Finally, let's take a look at how FreeU affects Stable Diffusion XL results:
```py
from diffusers import DiffusionPipeline
-import torch
+import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,
@@ -101,13 +88,13 @@ pipeline = DiffusionPipeline.from_pretrained(
prompt = "A squirrel eating a burger"
seed = 2023
-# Comes from
+# Comes from
# https://wandb.ai/nasirk24/UNET-FreeU-SDXL/reports/FreeU-SDXL-Optimal-Parameters--Vmlldzo1NDg4NTUw
pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
-image
```
+

## Text-to-video generation
@@ -120,7 +107,8 @@ from diffusers.utils import export_to_video
import torch
model_id = "cerspense/zeroscope_v2_576w"
-pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16).to("cuda")
+pipe = pipe.to("cuda")
prompt = "an astronaut riding a horse on mars"
seed = 2023
@@ -132,4 +120,4 @@ video_frames = pipe(prompt, height=320, width=576, num_frames=30, generator=torc
export_to_video(video_frames, "astronaut_rides_horse.mp4")
```
-Thanks to [kadirnar](https://github.com/kadirnar/) for helping to integrate the feature, and to [justindujardin](https://github.com/justindujardin) for the helpful discussions.
+Thanks to [kadirnar](https://github.com/kadirnar/) for helping to integrate the feature, and to [justindujardin](https://github.com/justindujardin) for the helpful discussions.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/img2img.md b/docs/source/en/using-diffusers/img2img.md
index 6014d87b7906..a1a4733a514c 100644
--- a/docs/source/en/using-diffusers/img2img.md
+++ b/docs/source/en/using-diffusers/img2img.md
@@ -21,15 +21,13 @@ With 🤗 Diffusers, this is as easy as 1-2-3:
1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint:
```py
-import torch
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
-)
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
```
@@ -50,7 +48,7 @@ init_image = load_image("https://huggingface.co/datasets/huggingface/documentati
```py
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
image = pipeline(prompt, image=init_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+image
```
@@ -70,29 +68,31 @@ The most popular image-to-image models are [Stable Diffusion v1.5](https://huggi
### Stable Diffusion v1.5
-Stable Diffusion v1.5 is a latent diffusion model initialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
+Stable Diffusion v1.5 is a latent diffusion model initialized from an earlier checkpoint, and further finetuned for 595K steps on 512x512 images. To use this pipeline for image-to-image, you'll need to prepare an initial image to pass to the pipeline. Then you can pass a prompt and the image to the pipeline to generate a new image:
```py
import torch
+import requests
+from PIL import Image
+from io import BytesIO
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+image
```
@@ -112,25 +112,27 @@ SDXL is a more powerful version of the Stable Diffusion model. It uses a larger
```py
import torch
+import requests
+from PIL import Image
+from io import BytesIO
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"
-init_image = load_image(url)
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, strength=0.5).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+image
```
@@ -152,25 +154,27 @@ The simplest way to use Kandinsky 2.2 is:
```py
import torch
+import requests
+from PIL import Image
+from io import BytesIO
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
-)
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+image
```
@@ -195,29 +199,32 @@ There are several important parameters you can configure in the pipeline that'll
- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored
- 📉 a lower `strength` value means the generated image is more similar to the initial image
-The `strength` and `num_inference_steps` parameters are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
+The `strength` and `num_inference_steps` parameters are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
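As a quick illustration of that arithmetic (this is not pipeline API, just the relationship described above):

```py
# Illustrative only: `strength` controls how many of the scheduler's steps actually run.
num_inference_steps = 50
strength = 0.8
effective_steps = int(num_inference_steps * strength)
print(effective_steps)  # 40 steps of noise added, then 40 denoising steps
```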
```py
import torch
+import requests
+from PIL import Image
+from io import BytesIO
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, strength=0.8).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+image
```
@@ -243,25 +250,27 @@ You can combine `guidance_scale` with `strength` for even more precise control o
```py
import torch
+import requests
+from PIL import Image
+from io import BytesIO
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+image
```
@@ -285,36 +294,38 @@ A negative prompt conditions the model to *not* include things in an image, and
```py
import torch
+import requests
+from PIL import Image
+from io import BytesIO
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
# pass prompt and image to pipeline
image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+image
```

-
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
+
negative prompt = "ugly, deformed, disfigured, poor details, bad anatomy"

-
negative_prompt = "jungle"
+
negative prompt = "jungle"
@@ -331,54 +342,52 @@ Start by generating an image with the text-to-image pipeline:
```py
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch
-from diffusers.utils import make_image_grid
pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-text2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
-text2image
+image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
```
Now you can pass this generated image to the image-to-image pipeline:
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
-)
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=text2image).images[0]
-make_image_grid([text2image, image2image], rows=1, cols=2)
+image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=image).images[0]
+image
```
### Image-to-image-to-image
You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generating short GIFs, restoring color to an image, or restoring missing areas of an image.
Start by generating an image:
```py
import torch
+import requests
+from PIL import Image
+from io import BytesIO
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
@@ -395,11 +404,10 @@ It is important to specify `output_type="latent"` in the pipeline to keep all th
Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):
```py
-pipeline = AutoPipelineForImage2Image.from_pretrained(
- "ogkalu/Comic-Diffusion", torch_dtype=torch.float16
-)
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "ogkalu/Comic-Diffusion", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# need to include the token "charliebo artstyle" in the prompt to use this checkpoint
@@ -410,15 +418,14 @@ Repeat one more time to generate the final image in a [pixel art style](https://
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kohbanye/pixel-art-style", torch_dtype=torch.float16
-)
+ "kohbanye/pixel-art-style", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# need to include the token "pixelartstyle" in the prompt to use this checkpoint
image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
+image
```
### Image-to-upscaler-to-super-resolution
@@ -429,19 +436,21 @@ Start with an image-to-image pipeline:
```py
import torch
+import requests
+from PIL import Image
+from io import BytesIO
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
@@ -458,11 +467,9 @@ It is important to specify `output_type="latent"` in the pipeline to keep all th
Chain it to an upscaler pipeline to increase the image resolution:
```py
from diffusers import StableDiffusionLatentUpscalePipeline

upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
upscaler.enable_model_cpu_offload()
upscaler.enable_xformers_memory_efficient_attention()
@@ -472,16 +479,14 @@ image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
Finally, chain it to a super-resolution pipeline to further enhance the resolution:
```py
from diffusers import StableDiffusionUpscalePipeline

super_res = StableDiffusionUpscalePipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
super_res.enable_model_cpu_offload()
super_res.enable_xformers_memory_efficient_attention()
-image_3 = super_res(prompt, image=image_2).images[0]
-make_image_grid([init_image, image_3.resize((512, 512))], rows=1, cols=2)
+image_3 = super_res(prompt, image=image_2).images[0]
+image_3
```
## Control image generation
@@ -499,14 +504,13 @@ from diffusers import AutoPipelineForImage2Image
import torch
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
    negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
image=init_image,
).images[0]
```
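
The `prompt_embeds` (and `negative_prompt_embeds`) used above could be produced with Compel roughly as follows; this is only a minimal sketch (it assumes Compel is installed, and the weighted prompt strings are illustrative):

```py
from compel import Compel

# reuse the tokenizer and text encoder from the already-loaded pipeline
compel_proc = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)

# "++" upweights a term in Compel's prompt weighting syntax
prompt_embeds = compel_proc("Astronaut in a jungle++, cold color palette, muted colors, detailed, 8k")
negative_prompt_embeds = compel_proc("ugly, deformed, disfigured, poor details, bad anatomy")
```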
@@ -518,28 +522,26 @@ ControlNets provide a more flexible and accurate way to control image generation
For example, let's condition an image with a depth map to keep the spatial information in the image.
```py
-from diffusers.utils import load_image, make_image_grid
-
+import requests
+from PIL import Image
+from io import BytesIO
+from diffusers.utils import load_image
+
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-init_image = load_image(url)
+response = requests.get(url)
+init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((958, 960)) # resize to depth image dimensions
depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png")
-make_image_grid([init_image, depth_image], rows=1, cols=2)
```
Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]:
```py
from diffusers import ControlNetModel, AutoPipelineForImage2Image
+from diffusers.utils import load_image
import torch
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
```
@@ -547,8 +549,8 @@ Now generate a new image conditioned on the depth map, initial image, and prompt
```py
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image_control_net = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
-make_image_grid([init_image, depth_image, image_control_net], rows=1, cols=3)
+image = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
+image
```
@@ -571,16 +573,15 @@ Let's apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
-image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image_control_net, strength=0.45, guidance_scale=10.5).images[0]
-make_image_grid([init_image, depth_image, image_control_net, image_elden_ring], rows=2, cols=2)
+image = pipeline(prompt, negative_prompt=negative_prompt, image=image, strength=0.45, guidance_scale=10.5).images[0]
+image
```
@@ -596,10 +597,10 @@ Running diffusion models is computationally expensive and intensive, but with a
+ pipeline.enable_xformers_memory_efficient_attention()
```
With [`torch.compile`](../optimization/torch2.0#torchcompile), you can boost your inference speed even more by wrapping your UNet with it:
```py
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
To learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/docs/source/en/using-diffusers/inpaint.md b/docs/source/en/using-diffusers/inpaint.md
index e6b1010f13b0..730cddf971a4 100644
--- a/docs/source/en/using-diffusers/inpaint.md
+++ b/docs/source/en/using-diffusers/inpaint.md
@@ -23,13 +23,12 @@ With 🤗 Diffusers, here is how you can do inpainting:
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForInpainting.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
```
@@ -42,8 +41,8 @@ You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu
2. Load the base and mask images:
```py
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
```
3. Create a prompt to inpaint the image with and pass it to the pipeline with the base and mask images:
@@ -52,7 +51,6 @@ mask_image = load_image("https://huggingface.co/datasets/huggingface/documentati
prompt = "a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k"
negative_prompt = "bad anatomy, deformed, ugly, disfigured"
image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
-make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -60,10 +58,6 @@ make_image_grid([init_image, mask_image, image], rows=1, cols=3)
base image
-
-

-
mask image
-
generated image
@@ -85,7 +79,7 @@ Upload a base image to inpaint on and use the sketch tool to draw a mask. Once y
## Popular models
-[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2 Inpainting](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
+[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
### Stable Diffusion Inpainting
@@ -94,23 +88,21 @@ Stable Diffusion Inpainting is a latent diffusion model finetuned on 512x512 ima
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
generator = torch.Generator("cuda").manual_seed(92)
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
-make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
### Stable Diffusion XL (SDXL) Inpainting
@@ -120,23 +112,21 @@ SDXL is a larger and more powerful version of Stable Diffusion v1.5. This model
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForInpainting.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16"
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
generator = torch.Generator("cuda").manual_seed(92)
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
-make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
### Kandinsky 2.2 Inpainting
@@ -146,23 +136,21 @@ The Kandinsky model family is similar to SDXL because it uses two models as well
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForInpainting.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
generator = torch.Generator("cuda").manual_seed(92)
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
-make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -184,183 +172,6 @@ make_image_grid([init_image, mask_image, image], rows=1, cols=3)
-## Non-inpaint specific checkpoints
-
-So far, this guide has used inpaint specific checkpoints such as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). But you can also use regular checkpoints like [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Let's compare the results of the two checkpoints.
-
-The image on the left is generated from a regular checkpoint, and the image on the right is from an inpaint checkpoint. You'll immediately notice the image on the left is not as clean, and you can still see the outline of the area the model is supposed to inpaint. The image on the right is much cleaner and the inpainted area appears more natural.
-
-
-
-
-```py
-import torch
-from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
-
-pipeline = AutoPipelineForInpainting.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
-
-generator = torch.Generator("cuda").manual_seed(92)
-prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
-image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
-
-
-```py
-import torch
-from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
-
-pipeline = AutoPipelineForInpainting.from_pretrained(
- "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
-
-generator = torch.Generator("cuda").manual_seed(92)
-prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
-image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
-
-
-
-
-

-
runwayml/stable-diffusion-v1-5
-
-
-

-
runwayml/stable-diffusion-inpainting
-
-
-
-However, for more basic tasks like erasing an object from an image (like the rocks in the road for example), a regular checkpoint yields pretty good results. There isn't as noticeable of difference between the regular and inpaint checkpoint.
-
-
-
-
-```py
-import torch
-from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
-
-pipeline = AutoPipelineForInpainting.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/road-mask.png")
-
-image = pipeline(prompt="road", image=init_image, mask_image=mask_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
-
-
-```py
-import torch
-from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
-
-pipeline = AutoPipelineForInpainting.from_pretrained(
- "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
-pipeline.enable_xformers_memory_efficient_attention()
-
-# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/road-mask.png")
-
-image = pipeline(prompt="road", image=init_image, mask_image=mask_image).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
-```
-
-
-
-
-
-
-

-
runwayml/stable-diffusion-v1-5
-
-
-

-
runwayml/stable-diffusion-inpainting
-
-
-
-The trade-off of using a non-inpaint specific checkpoint is the overall image quality may be lower, but it generally tends to preserve the mask area (that is why you can see the mask outline). The inpaint specific checkpoints are intentionally trained to generate higher quality inpainted images, and that includes creating a more natural transition between the masked and unmasked areas. As a result, these checkpoints are more likely to change your unmasked area.
-
-If preserving the unmasked area is important for your task, you can use the code below to force the unmasked area of an image to remain the same at the expense of some more unnatural transitions between the masked and unmasked areas.
-
-```py
-import PIL
-import numpy as np
-import torch
-
-from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
-
-device = "cuda"
-pipeline = AutoPipelineForInpainting.from_pretrained(
- "runwayml/stable-diffusion-inpainting",
- torch_dtype=torch.float16,
-)
-pipeline = pipeline.to(device)
-
-img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
-
-init_image = load_image(img_url).resize((512, 512))
-mask_image = load_image(mask_url).resize((512, 512))
-
-prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
-repainted_image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
-repainted_image.save("repainted_image.png")
-
-# Convert mask to grayscale NumPy array
-mask_image_arr = np.array(mask_image.convert("L"))
-# Add a channel dimension to the end of the grayscale mask
-mask_image_arr = mask_image_arr[:, :, None]
-# Binarize the mask: 1s correspond to the pixels which are repainted
-mask_image_arr = mask_image_arr.astype(np.float32) / 255.0
-mask_image_arr[mask_image_arr < 0.5] = 0
-mask_image_arr[mask_image_arr >= 0.5] = 1
-
-# Take the masked pixels from the repainted image and the unmasked pixels from the initial image
-unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image + mask_image_arr * repainted_image
-unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.round().astype("uint8"))
-unmasked_unchanged_image.save("force_unmasked_unchanged.png")
-make_image_grid([init_image, mask_image, repainted_image, unmasked_unchanged_image], rows=2, cols=2)
-```
-
## Configure pipeline parameters
Image features - like quality and "creativity" - are dependent on pipeline parameters. Knowing what these parameters do is important for getting the results you want. Let's take a look at the most important parameters and see how changing them affects the output.
@@ -375,22 +186,20 @@ Image features - like quality and "creativity" - are dependent on pipeline param
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.6).images[0]
-make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -420,22 +229,20 @@ You can use `strength` and `guidance_scale` together for more control over how e
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, guidance_scale=2.5).images[0]
-make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -460,23 +267,22 @@ A negative prompt assumes the opposite role of a prompt; it guides the model awa
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
negative_prompt = "bad architecture, unstable, poor details, blurry"
image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
-make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+image
```
@@ -486,6 +292,50 @@ make_image_grid([init_image, mask_image, image], rows=1, cols=3)
+## Preserve unmasked areas
+
+The [`AutoPipelineForInpainting`] (and other inpainting pipelines) generally changes the unmasked parts of an image to create a more natural transition between the masked and unmasked region. If this behavior is undesirable, you can force the unmasked area to remain the same. However, forcing the unmasked portion of the image to remain the same may result in some unusual transitions between the unmasked and masked areas.
+
+```py
+import PIL
+import numpy as np
+import torch
+
+from diffusers import AutoPipelineForInpainting
+from diffusers.utils import load_image
+
+device = "cuda"
+pipeline = AutoPipelineForInpainting.from_pretrained(
+ "runwayml/stable-diffusion-inpainting",
+ torch_dtype=torch.float16,
+)
+pipeline = pipeline.to(device)
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url).resize((512, 512))
+mask_image = load_image(mask_url).resize((512, 512))
+
+prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+repainted_image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+repainted_image.save("repainted_image.png")
+
+# Convert mask to grayscale NumPy array
+mask_image_arr = np.array(mask_image.convert("L"))
+# Add a channel dimension to the end of the grayscale mask
+mask_image_arr = mask_image_arr[:, :, None]
+# Binarize the mask: 1s correspond to the pixels which are repainted
+mask_image_arr = mask_image_arr.astype(np.float32) / 255.0
+mask_image_arr[mask_image_arr < 0.5] = 0
+mask_image_arr[mask_image_arr >= 0.5] = 1
+
+# Take the masked pixels from the repainted image and the unmasked pixels from the initial image
+unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image + mask_image_arr * repainted_image
+unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.round().astype("uint8"))
+unmasked_unchanged_image.save("force_unmasked_unchanged.png")
+```
+
## Chained inpainting pipelines
[`AutoPipelineForInpainting`] can be chained with other 🤗 Diffusers pipelines to edit their outputs. This is often useful for improving the output quality from your other diffusion pipelines, and if you're using multiple pipelines, it can be more memory-efficient to chain them together to keep the outputs in latent space and reuse the same pipeline components.
@@ -499,37 +349,35 @@ Start with the text-to-image pipeline to create a castle:
```py
import torch
from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-text2image = pipeline("concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k").images[0]
+image = pipeline("concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k").images[0]
```
Load the mask image of the output from above:
```py
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_text-chain-mask.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_text-chain-mask.png").convert("RGB")
```
And let's inpaint the masked area with a waterfall:
```py
pipeline = AutoPipelineForInpainting.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
-)
+ "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
prompt = "digital painting of a fantasy waterfall, cloudy"
-image = pipeline(prompt=prompt, image=text2image, mask_image=mask_image).images[0]
-make_image_grid([text2image, mask_image, image], rows=1, cols=3)
+image = pipeline(prompt=prompt, image=image, mask_image=mask_image).images[0]
+image
```
@@ -543,6 +391,7 @@ make_image_grid([text2image, mask_image, image], rows=1, cols=3)
+
### Inpaint-to-image-to-image
You can also chain an inpainting pipeline before another pipeline like image-to-image or an upscaler to improve the quality.
@@ -552,24 +401,23 @@ Begin by inpainting an image:
```py
import torch
from diffusers import AutoPipelineForInpainting, AutoPipelineForImage2Image
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
-image_inpainting = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
# resize image to 1024x1024 for SDXL
-image_inpainting = image_inpainting.resize((1024, 1024))
+image = image.resize((1024, 1024))
```
Now let's pass the image to another inpainting pipeline with SDXL's refiner model to enhance the image details and quality:
@@ -577,12 +425,11 @@ Now let's pass the image to another inpainting pipeline with SDXL's refiner mode
```py
pipeline = AutoPipelineForInpainting.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16"
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image = pipeline(prompt=prompt, image=image_inpainting, mask_image=mask_image, output_type="latent").images[0]
+image = pipeline(prompt=prompt, image=image, mask_image=mask_image, output_type="latent").images[0]
```
@@ -595,11 +442,9 @@ Finally, you can pass this image to an image-to-image pipeline to put the finish
```py
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline)
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline(prompt=prompt, image=image).images[0]
-make_image_grid([init_image, mask_image, image_inpainting, image], rows=2, cols=2)
```
@@ -632,21 +477,18 @@ Once you've generated the embeddings, pass them to the `prompt_embeds` (and `neg
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16,
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
    negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
image=init_image,
mask_image=mask_image
).images[0]
-make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
### ControlNet
@@ -659,7 +501,7 @@ For example, let's condition an image with a ControlNet pretrained on inpaint im
import torch
import numpy as np
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
# load ControlNet
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16, variant="fp16")
@@ -667,14 +509,13 @@ controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpai
# pass ControlNet to the pipeline
pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
# prepare control image
def make_inpaint_condition(init_image, mask_image):
@@ -695,7 +536,7 @@ Now generate an image from the base, mask and control images. You'll notice feat
```py
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image).images[0]
-make_image_grid([init_image, mask_image, PIL.Image.fromarray(np.uint8(control_image[0][0])).convert('RGB'), image], rows=2, cols=2)
+image
```
You can take this a step further and chain it with an image-to-image pipeline to apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion):
@@ -705,16 +546,15 @@ from diffusers import AutoPipelineForImage2Image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
-)
+).to("cuda")
pipeline.enable_model_cpu_offload()
-# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
prompt = "elden ring style castle" # include the token "elden ring style" in the prompt
negative_prompt = "bad architecture, deformed, disfigured, poor details"
-image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
-make_image_grid([init_image, mask_image, image, image_elden_ring], rows=2, cols=2)
+image = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
+image
```
@@ -734,19 +574,19 @@ make_image_grid([init_image, mask_image, image, image_elden_ring], rows=2, cols=
## Optimize
It can be difficult and slow to run diffusion models if you're resource constrained, but it doesn't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. Both options reduce memory usage and accelerate inference.
You can also offload the model to the CPU to save even more memory:
```diff
+ pipeline.enable_xformers_memory_efficient_attention()
+ pipeline.enable_model_cpu_offload()
```
-To speed-up your inference code even more, use [`torch_compile`](../optimization/torch2.0#torchcompile). You should wrap `torch.compile` around the most intensive component in the pipeline which is typically the UNet:
+To speed up your inference code even more, use [`torch.compile`](../optimization/torch2.0#torchcompile). You should wrap `torch.compile` around the most intensive component in the pipeline, which is typically the UNet:
```py
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
-Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
+Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/loading.md b/docs/source/en/using-diffusers/loading.md
index d9e19a5bdd2a..3fb11ac92c1f 100644
--- a/docs/source/en/using-diffusers/loading.md
+++ b/docs/source/en/using-diffusers/loading.md
@@ -29,11 +29,11 @@ This guide will show you how to load:
💡 Skip to the [DiffusionPipeline explained](#diffusionpipeline-explained) section if you are interested in learning in more detail about how the [`DiffusionPipeline`] class works.
-The [`DiffusionPipeline`] class is the simplest and most generic way to load the latest trending diffusion model from the [Hub](https://huggingface.co/models?library=diffusers&sort=trending). The [`DiffusionPipeline.from_pretrained`] method automatically detects the correct pipeline class from the checkpoint, downloads, and caches all the required configuration and weight files, and returns a pipeline instance ready for inference.
+The [`DiffusionPipeline`] class is the simplest and most generic way to load any diffusion model from the [Hub](https://huggingface.co/models?library=diffusers). The [`DiffusionPipeline.from_pretrained`] method automatically detects the correct pipeline class from the checkpoint, downloads and caches all the required configuration and weight files, and returns a pipeline instance ready for inference.
```python
from diffusers import DiffusionPipeline
@@ -42,7 +42,7 @@ repo_id = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
```
You can also load a checkpoint with its specific pipeline class. The example above loaded a Stable Diffusion model; to get the same result, use the [`StableDiffusionPipeline`] class:
```python
from diffusers import StableDiffusionPipeline
@@ -51,7 +51,7 @@ repo_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
```
A checkpoint (such as [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) or [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)) may also be used for more than one task, like text-to-image or image-to-image. To differentiate what task you want to use the checkpoint for, you have to load it directly with its corresponding task-specific pipeline class:
```python
from diffusers import StableDiffusionImg2ImgPipeline
@@ -103,10 +103,12 @@ Let's use the [`SchedulerMixin.from_pretrained`] method to replace the default [
Then you can pass the new [`EulerDiscreteScheduler`] instance to the `scheduler` argument in [`DiffusionPipeline`]:
```python
-from diffusers import DiffusionPipeline, EulerDiscreteScheduler
+from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler
repo_id = "runwayml/stable-diffusion-v1-5"
+
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
+
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler, use_safetensors=True)
```
@@ -119,9 +121,6 @@ from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None, use_safetensors=True)
-"""
-You have disabled the safety checker for
-by passing `safety_checker=None`. Ensure that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling it only for use cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
-"""
```
### Reuse components across pipelines
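
The gist is to take the components of a pipeline you have already loaded and hand them to another pipeline class, so the model weights only need to be loaded into memory once. A minimal sketch (the variable names here are illustrative):

```py
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline

stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", use_safetensors=True
)
# reuse the already-loaded models instead of loading the checkpoint a second time
components = stable_diffusion_txt2img.components
stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)
```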
@@ -164,10 +163,10 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
## Checkpoint variants
A checkpoint variant is usually a checkpoint whose weights are:
- Stored in a different floating point type for lower precision and lower storage, such as [`torch.float16`](https://pytorch.org/docs/stable/tensors.html#data-types), because it only requires half the bandwidth and storage to download. You can't use this variant if you're continuing training or using a CPU.
- Non-exponential mean averaged (EMA) weights, which shouldn't be used for inference. You should use these to continue fine-tuning a model.
@@ -175,7 +174,7 @@ A checkpoint variant is usually a checkpoint whose weights are:
Otherwise, a variant is **identical** to the original checkpoint. They have exactly the same serialization format (like [Safetensors](./using_safetensors)), model structure, and weights that have identical tensor shapes.
| **checkpoint type** | **weight name** | **argument for loading weights** |
|---------------------|-------------------------------------|----------------------------------|
@@ -203,7 +202,7 @@ stable_diffusion = DiffusionPipeline.from_pretrained(
)
```
-To save a checkpoint stored in a different floating-point type or as a non-EMA variant, use the [`DiffusionPipeline.save_pretrained`] method and specify the `variant` argument. You should try and save a variant to the same folder as the original checkpoint, so you can load both from the same folder:
+To save a checkpoint stored in a different floating point type or as a non-EMA variant, use the [`DiffusionPipeline.save_pretrained`] method and specify the `variant` argument. You should try and save a variant to the same folder as the original checkpoint, so you can load both from the same folder:
```python
from diffusers import DiffusionPipeline
@@ -232,7 +231,7 @@ TODO(Patrick) - Make sure to uncomment this part as soon as things are deprecate
#### Using `revision` to load pipeline variants is deprecated
-Previously the `revision` argument of [`DiffusionPipeline.from_pretrained`] was heavily used to
+Previously the `revision` argument of [`DiffusionPipeline.from_pretrained`] was heavily used to
load model variants, e.g.:
```python
@@ -247,8 +246,8 @@ The above example is therefore deprecated and won't be supported anymore for `di
-If you load diffusers pipelines or models with `revision="fp16"` or `revision="non_ema"`,
-please make sure to update the code and use `variant="fp16"` or `variation="non_ema"` respectively
+If you load diffusers pipelines or models with `revision="fp16"` or `revision="non_ema"`,
+please make sure to update the code and use `variant="fp16"` or `variant="non_ema"` respectively
instead.
@@ -256,7 +255,7 @@ instead.
## Models
-Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
+Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of redownloading them.
Models can be loaded from a subfolder with the `subfolder` argument. For example, the model weights for `runwayml/stable-diffusion-v1-5` are stored in the [`unet`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/unet) subfolder:
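
A minimal sketch of that call (assuming the standard `runwayml/stable-diffusion-v1-5` repository layout):

```py
from diffusers import UNet2DConditionModel

# load only the UNet weights and config from the "unet" subfolder of the checkpoint
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True
)
```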
@@ -282,9 +281,9 @@ You can also load and save model variants by specifying the `variant` argument i
from diffusers import UNet2DConditionModel
model = UNet2DConditionModel.from_pretrained(
- "runwayml/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
+ "runwayml/stable-diffusion-v1-5", subfolder="unet", variant="non-ema", use_safetensors=True
)
model.save_pretrained("./local-unet", variant="non_ema")
```
## Schedulers
@@ -292,7 +291,7 @@ model.save_pretrained("./local-unet", variant="non_ema")
Schedulers are loaded from the [`SchedulerMixin.from_pretrained`] method, and unlike models, schedulers are **not parameterized** or **trained**; they are defined by a configuration file.
Loading schedulers does not consume any significant amount of memory and the same configuration file can be used for a variety of different schedulers.
For example, the following schedulers are compatible with [`StableDiffusionPipeline`], which means you can load the same scheduler configuration file in any of these classes:
```python
from diffusers import StableDiffusionPipeline
@@ -301,8 +300,8 @@ from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
- EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
+ EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
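
# for example, the same "scheduler" config from the checkpoint can be loaded into any of the classes above
# (a minimal sketch; it assumes the usual v1-5 layout with a "scheduler" subfolder)
repo_id = "runwayml/stable-diffusion-v1-5"
ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")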
@@ -325,9 +324,9 @@ pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm, use_s
As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things:
- Download the latest version of the folder structure required for inference and cache it. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] reuses the cache and won't redownload the files.
- Load the cached weights into the correct pipeline [class](../api/pipelines/overview#diffusers-summary) - retrieved from the `model_index.json` file - and return an instance of it.
The pipelines' underlying folder structure corresponds directly with their class instances. For example, the [`StableDiffusionPipeline`] corresponds to the folder structure in [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5).
```python
from diffusers import DiffusionPipeline
@@ -339,13 +338,13 @@ print(pipeline)
You'll see pipeline is an instance of [`StableDiffusionPipeline`], which consists of seven components:
-- `"feature_extractor"`: a [`~transformers.CLIPImageProcessor`] from 🤗 Transformers.
+- `"feature_extractor"`: a [`~transformers.CLIPFeatureExtractor`] from 🤗 Transformers.
- `"safety_checker"`: a [component](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32) for screening against harmful content.
- `"scheduler"`: an instance of [`PNDMScheduler`].
- `"text_encoder"`: a [`~transformers.CLIPTextModel`] from 🤗 Transformers.
- `"tokenizer"`: a [`~transformers.CLIPTokenizer`] from 🤗 Transformers.
- `"unet"`: an instance of [`UNet2DConditionModel`].
-- `"vae"`: an instance of [`AutoencoderKL`].
+- `"vae"` an instance of [`AutoencoderKL`].
```json
StableDiffusionPipeline {
@@ -380,7 +379,7 @@ StableDiffusionPipeline {
}
```
-Compare the components of the pipeline instance to the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) folder structure, and you'll see there is a separate folder for each of the components in the repository:
+Compare the components of the pipeline instance to the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) folder structure, and you'll see there is a separate folder for each of the components in the repository:
```
.
@@ -389,18 +388,12 @@ Compare the components of the pipeline instance to the [`runwayml/stable-diffusi
├── model_index.json
├── safety_checker
│ ├── config.json
-| ├── model.fp16.safetensors
-│ ├── model.safetensors
-│ ├── pytorch_model.bin
-| └── pytorch_model.fp16.bin
+│ └── pytorch_model.bin
├── scheduler
│ └── scheduler_config.json
├── text_encoder
│ ├── config.json
-| ├── model.fp16.safetensors
-│ ├── model.safetensors
-│ |── pytorch_model.bin
-| └── pytorch_model.fp16.bin
+│ └── pytorch_model.bin
├── tokenizer
│ ├── merges.txt
│ ├── special_tokens_map.json
@@ -409,17 +402,9 @@ Compare the components of the pipeline instance to the [`runwayml/stable-diffusi
├── unet
│ ├── config.json
│ ├── diffusion_pytorch_model.bin
-| |── diffusion_pytorch_model.fp16.bin
-│ |── diffusion_pytorch_model.f16.safetensors
-│ |── diffusion_pytorch_model.non_ema.bin
-│ |── diffusion_pytorch_model.non_ema.safetensors
-│ └── diffusion_pytorch_model.safetensors
-|── vae
-. ├── config.json
-. ├── diffusion_pytorch_model.bin
- ├── diffusion_pytorch_model.fp16.bin
- ├── diffusion_pytorch_model.fp16.safetensors
- └── diffusion_pytorch_model.safetensors
+└── vae
+ ├── config.json
+ ├── diffusion_pytorch_model.bin
```
You can access each of the components of the pipeline as an attribute to view its configuration:
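For instance, reading the `tokenizer` attribute of the `pipeline` loaded above returns the configuration shown below (a sketch of the call that produces that output):

```py
pipeline.tokenizer
```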
@@ -439,11 +424,10 @@ CLIPTokenizer(
"unk_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
"pad_token": "<|endoftext|>",
},
- clean_up_tokenization_spaces=True
)
```
-Every pipeline expects a [`model_index.json`](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json) file that tells the [`DiffusionPipeline`]:
+Every pipeline expects a `model_index.json` file that tells the [`DiffusionPipeline`]:
- which pipeline class to load from `_class_name`
- which version of 🧨 Diffusers was used to create the model in `_diffusers_version`
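As a sketch (assuming [`~ConfigMixin.load_config`] behaves as in current releases), you can inspect these entries without downloading any model weights:

```py
from diffusers import DiffusionPipeline

# reads model_index.json from the Hub (or the local cache) without loading any weights
config = DiffusionPipeline.load_config("runwayml/stable-diffusion-v1-5")
print(config["_class_name"], config["_diffusers_version"])
```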
diff --git a/docs/source/en/using-diffusers/loading_overview.md b/docs/source/en/using-diffusers/loading_overview.md
index b36fdb77e6dd..df870505219b 100644
--- a/docs/source/en/using-diffusers/loading_overview.md
+++ b/docs/source/en/using-diffusers/loading_overview.md
@@ -14,4 +14,4 @@ specific language governing permissions and limitations under the License.
🧨 Diffusers offers many pipelines, models, and schedulers for generative tasks. To make loading these components as simple as possible, we provide a single and unified method - `from_pretrained()` - that loads any of these components from either the Hugging Face [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) or your local machine. Whenever you load a pipeline or model, the latest files are automatically downloaded and cached so you can quickly reuse them next time without redownloading the files.
-This section will show you everything you need to know about loading pipelines, how to load different components in a pipeline, how to load checkpoint variants, and how to load community pipelines. You'll also learn how to load schedulers and compare the speed and quality trade-offs of using different schedulers. Finally, you'll see how to convert and load KerasCV checkpoints so you can use them in PyTorch with 🧨 Diffusers.
+This section will show you everything you need to know about loading pipelines, how to load different components in a pipeline, how to load checkpoint variants, and how to load community pipelines. You'll also learn how to load schedulers and compare the speed and quality trade-offs of using different schedulers. Finally, you'll see how to convert and load KerasCV checkpoints so you can use them in PyTorch with 🧨 Diffusers.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md
index 6f8e00d1e396..c2f10ff79637 100644
--- a/docs/source/en/using-diffusers/other-formats.md
+++ b/docs/source/en/using-diffusers/other-formats.md
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-Stable Diffusion models are available in different formats depending on the framework they're trained and saved with, and where you download them from. Converting these formats for use in 🤗 Diffusers allows you to use all the features supported by the library, such as [using different schedulers](schedulers) for inference, [building your custom pipeline](write_own_pipeline), and a variety of techniques and methods for [optimizing inference speed](../optimization/opt_overview).
+Stable Diffusion models are available in different formats depending on the framework they're trained and saved with, and where you download them from. Converting these formats for use in 🤗 Diffusers allows you to use all the features supported by the library, such as [using different schedulers](schedulers) for inference, [building your custom pipeline](write_own_pipeline), and a variety of techniques and methods for [optimizing inference speed](../optimization/opt_overview).
@@ -28,17 +28,17 @@ This guide will show you how to convert other Stable Diffusion formats to be com
The checkpoint - or `.ckpt` - format is commonly used to store and save models. The `.ckpt` file contains the entire model and is typically several GBs in size. While you can load and use a `.ckpt` file directly with the [`~StableDiffusionPipeline.from_single_file`] method, it is generally better to convert the `.ckpt` file to 🤗 Diffusers so both formats are available.
-There are two options for converting a `.ckpt` file: use a Space to convert the checkpoint or convert the `.ckpt` file with a script.
+There are two options for converting a `.ckpt` file: use a Space to convert the checkpoint or convert the `.ckpt` file with a script.
### Convert with a Space
The easiest and most convenient way to convert a `.ckpt` file is to use the [SD to Diffusers](https://huggingface.co/spaces/diffusers/sd-to-diffusers) Space. You can follow the instructions on the Space to convert the `.ckpt` file.
-This approach works well for basic models, but it may struggle with more customized models. You'll know the Space failed if it returns an empty pull request or error. In this case, you can try converting the `.ckpt` file with a script.
+This approach works well for basic models, but it may struggle with more customized models. You'll know the Space failed if it returns an empty pull request or error. In this case, you can try converting the `.ckpt` file with a script.
### Convert with a script
-🤗 Diffusers provides a [conversion script](https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py) for converting `.ckpt` files. This approach is more reliable than the Space above.
+🤗 Diffusers provides a [conversion script](https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py) for converting `.ckpt` files. This approach is more reliable than the Space above.
Before you start, make sure you have a local clone of 🤗 Diffusers to run the script and log in to your Hugging Face account so you can open pull requests and push your converted model to the Hub.
@@ -86,11 +86,11 @@ git push origin pr/13:refs/pr/13
-🧪 This is an experimental feature. Only Stable Diffusion v1 checkpoints are supported by the Convert KerasCV Space at the moment.
+🧪 This is an experimental feature. Only Stable Diffusion v1 checkpoints are supported by the Convert KerasCV Space at the moment.
-[KerasCV](https://keras.io/keras_cv/) supports training for [Stable Diffusion](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion) v1 and v2. However, it offers limited support for experimenting with Stable Diffusion models for inference and deployment whereas 🤗 Diffusers has a more complete set of features for this purpose, such as different [noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers), [flash attention](https://huggingface.co/docs/diffusers/optimization/xformers), and [other
+[KerasCV](https://keras.io/keras_cv/) supports training for [Stable Diffusion](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion) v1 and v2. However, it offers limited support for experimenting with Stable Diffusion models for inference and deployment whereas 🤗 Diffusers has a more complete set of features for this purpose, such as different [noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers), [flash attention](https://huggingface.co/docs/diffusers/optimization/xformers), and [other
optimization techniques](https://huggingface.co/docs/diffusers/optimization/fp16).
The [Convert KerasCV](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers) Space converts `.pb` or `.h5` files to PyTorch, and then wraps them in a [`StableDiffusionPipeline`] so it is ready for inference. The converted checkpoint is stored in a repository on the Hugging Face Hub.
@@ -116,7 +116,7 @@ pipeline = DiffusionPipeline.from_pretrained(
)
```
-Then, you can generate an image like:
+Then you can generate an image like:
```py
from diffusers import DiffusionPipeline
@@ -136,41 +136,53 @@ image = pipeline(prompt, num_inference_steps=50).images[0]
[Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) (A1111) is a popular web UI for Stable Diffusion that supports model sharing platforms like [Civitai](https://civitai.com/). Models trained with the Low-Rank Adaptation (LoRA) technique are especially popular because they're fast to train and have a much smaller file size than a fully finetuned model. 🤗 Diffusers supports loading A1111 LoRA checkpoints with [`~loaders.LoraLoaderMixin.load_lora_weights`]:
```py
-from diffusers import StableDiffusionXLPipeline
+from diffusers import DiffusionPipeline, UniPCMultistepScheduler
import torch
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
+pipeline = DiffusionPipeline.from_pretrained(
+ "andite/anything-v4.0", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
```
-Download a LoRA checkpoint from Civitai; this example uses the [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) checkpoint, but feel free to try out any LoRA checkpoint!
+Download a LoRA checkpoint from Civitai; this example uses the [Howls Moving Castle,Interior/Scenery LoRA (Ghibli Stlye)](https://civitai.com/models/14605?modelVersionId=19998) checkpoint, but feel free to try out any LoRA checkpoint!
```py
# uncomment to download the safetensor weights
-#!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors
+#!wget https://civitai.com/api/download/models/19998 -O howls_moving_castle.safetensors
```
Load the LoRA checkpoint into the pipeline with the [`~loaders.LoraLoaderMixin.load_lora_weights`] method:
```py
-pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors")
+pipeline.load_lora_weights(".", weight_name="howls_moving_castle.safetensors")
```
Now you can use the pipeline to generate images:
```py
-prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop"
+prompt = "masterpiece, illustration, ultra-detailed, cityscape, san francisco, golden gate bridge, california, bay area, in the snow, beautiful detailed starry sky"
negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
-image = pipeline(
+images = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
+ width=512,
+ height=512,
+ num_inference_steps=25,
+ num_images_per_prompt=4,
generator=torch.manual_seed(0),
-).images[0]
-image
+).images
+```
+
+Display the images:
+
+```py
+from diffusers.utils import make_image_grid
+
+make_image_grid(images, 2, 2)
```
-

+
diff --git a/docs/source/en/using-diffusers/pipeline_overview.md b/docs/source/en/using-diffusers/pipeline_overview.md
index 292ce51d322a..4ee25b51dc6f 100644
--- a/docs/source/en/using-diffusers/pipeline_overview.md
+++ b/docs/source/en/using-diffusers/pipeline_overview.md
@@ -14,4 +14,4 @@ specific language governing permissions and limitations under the License.
A pipeline is an end-to-end class that provides a quick and easy way to use a diffusion system for inference by bundling independently trained models and schedulers together. Certain combinations of models and schedulers define specific pipeline types, like [`StableDiffusionXLPipeline`] or [`StableDiffusionControlNetPipeline`], with specific capabilities. All pipeline types inherit from the base [`DiffusionPipeline`] class; pass it any checkpoint, and it'll automatically detect the pipeline type and load the necessary components.
-This section demonstrates how to use specific pipelines such as Stable Diffusion XL, ControlNet, and DiffEdit. You'll also learn how to use a distilled version of the Stable Diffusion model to speed up inference, how to create reproducible pipelines, and how to use and contribute community pipelines.
+This section introduces you to some of the more complex pipelines like Stable Diffusion XL, ControlNet, and DiffEdit, which require additional inputs. You'll also learn how to use a distilled version of the Stable Diffusion model to speed up inference, how to control randomness on your hardware when generating images, and how to create a community pipeline for a custom task like generating images from speech.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/push_to_hub.md b/docs/source/en/using-diffusers/push_to_hub.md
index 58598c3bc443..468386031768 100644
--- a/docs/source/en/using-diffusers/push_to_hub.md
+++ b/docs/source/en/using-diffusers/push_to_hub.md
@@ -1,15 +1,3 @@
-
-
# Push files to the Hub
[[open-in-colab]]
@@ -32,7 +20,7 @@ notebook_login()
## Models
-To push a model to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the model to be stored on the Hub:
+To push a model to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the model to be stored on the Hub:
```py
from diffusers import ControlNetModel
@@ -48,7 +36,7 @@ controlnet = ControlNetModel(
controlnet.push_to_hub("my-controlnet-model")
```
-For models, you can also specify the [*variant*](loading#checkpoint-variants) of the weights to push to the Hub. For example, to push `fp16` weights:
+For models, you can also specify the [*variant*](loading#checkpoint-variants) of the weights to push to the Hub. For example, to push `fp16` weights:
```py
controlnet.push_to_hub("my-controlnet-model", variant="fp16")
@@ -64,7 +52,7 @@ model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model")
## Scheduler
-To push a scheduler to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the scheduler to be stored on the Hub:
+To push a scheduler to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the scheduler to be stored on the Hub:
```py
from diffusers import DDIMScheduler
@@ -171,13 +159,13 @@ pipeline = StableDiffusionPipeline.from_pretrained("your-namespace/my-pipeline")
Set `private=True` in the [`~diffusers.utils.PushToHubMixin.push_to_hub`] function to keep your model, scheduler, or pipeline files private:
```py
-controlnet.push_to_hub("my-controlnet-model-private", private=True)
+controlnet.push_to_hub("my-controlnet-model", private=True)
```
-Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for.`
+Private repositories are only visible to you; other users won't be able to clone the repository, and it won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Repo not found` error.
-To load a model, scheduler, or pipeline from private or gated repositories, set `use_auth_token=True`:
+To load a model, scheduler, or pipeline from private or gated repositories, set `use_auth_token=True`:
```py
-model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model-private", use_auth_token=True)
-```
+model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model", use_auth_token=True)
+```
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/reproducibility.md b/docs/source/en/using-diffusers/reproducibility.md
index 5bc1d02b14d4..0da760f0192d 100644
--- a/docs/source/en/using-diffusers/reproducibility.md
+++ b/docs/source/en/using-diffusers/reproducibility.md
@@ -55,7 +55,7 @@ But if you need to reliably generate the same image, that'll depend on whether y
### CPU
-To generate reproducible results on a CPU, you'll need to use a PyTorch [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed:
+To generate reproducible results on a CPU, you'll need to use a PyTorch [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed:
```python
import torch
@@ -83,7 +83,7 @@ If you run this code example on your specific hardware and PyTorch version, you
💡 It might be a bit unintuitive at first to pass `Generator` objects to the pipeline instead of
just integer values representing the seed, but this is the recommended design when dealing with
-probabilistic models in PyTorch, as `Generator`s are *random states* that can be
+probabilistic models in PyTorch, as `Generator`s are *random states* that can be
passed to multiple pipelines in a sequence.
@@ -153,13 +153,12 @@ exactly the same hardware and PyTorch version for full reproducibility.
You can also configure PyTorch to use deterministic algorithms to create a reproducible pipeline. However, you should be aware that deterministic algorithms may be slower than nondeterministic ones and you may observe a decrease in performance. But if reproducibility is important to you, then this is the way to go!
-Nondeterministic behavior occurs when operations are launched in more than one CUDA stream. To avoid this, set the environment variable [`CUBLAS_WORKSPACE_CONFIG`](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime.
+Nondeterministic behavior occurs when operations are launched in more than one CUDA stream. To avoid this, set the environment variable [`CUBLAS_WORKSPACE_CONFIG`](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime.
PyTorch typically benchmarks multiple algorithms to select the fastest one, but if you want reproducibility, you should disable this feature because the benchmark may select different algorithms each time. Lastly, pass `True` to [`torch.use_deterministic_algorithms`](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html) to enable deterministic algorithms.
```py
import os
import torch
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
@@ -172,6 +171,7 @@ Now when you run the same pipeline twice, you'll get identical results.
```py
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline
+import numpy as np
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_safetensors=True).to("cuda")
@@ -186,6 +186,6 @@ result1 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="
g.manual_seed(0)
result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
-print("L_inf dist =", abs(result1 - result2).max())
-"L_inf dist = tensor(0., device='cuda:0')"
-```
+print("L_inf dist = ", abs(result1 - result2).max())
+"L_inf dist = tensor(0., device='cuda:0')"
+```
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/reusing_seeds.md b/docs/source/en/using-diffusers/reusing_seeds.md
index d2638b469e30..7cbaf2643202 100644
--- a/docs/source/en/using-diffusers/reusing_seeds.md
+++ b/docs/source/en/using-diffusers/reusing_seeds.md
@@ -16,7 +16,7 @@ specific language governing permissions and limitations under the License.
A common way to improve the quality of generated images is with *deterministic batch generation*: generate a batch of images and select one image to improve with a more detailed prompt in a second round of inference. The key is to pass a list of [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator)s to the pipeline for batched image generation, and tie each `Generator` to a seed so you can reuse it for an image.
-Let's use [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) for example, and generate several versions of the following prompt:
+Let's use [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) for example, and generate several versions of the following prompt:
```py
prompt = "Labrador in the style of Vermeer"
@@ -25,27 +25,27 @@ prompt = "Labrador in the style of Vermeer"
Instantiate a pipeline with [`DiffusionPipeline.from_pretrained`] and place it on a GPU (if available):
```python
-import torch
-from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
-
-pipe = DiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-)
-pipe = pipe.to("cuda")
+>>> import torch
+>>> from diffusers import DiffusionPipeline
+
+>>> pipe = DiffusionPipeline.from_pretrained(
+... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+... )
+>>> pipe = pipe.to("cuda")
```
-Now, define four different `Generator`s and assign each `Generator` a seed (`0` to `3`) so you can reuse a `Generator` later for a specific image:
+Now, define four different `Generator`s and assign each `Generator` a seed (`0` to `3`) so you can reuse a `Generator` later for a specific image:
```python
-generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
+>>> import torch
+
+>>> generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
```
Generate the images and have a look:
```python
-images = pipe(prompt, generator=generator, num_images_per_prompt=4).images
-make_image_grid(images, rows=2, cols=2)
+>>> images = pipe(prompt, generator=generator, num_images_per_prompt=4).images
+>>> images
```

@@ -60,8 +60,8 @@ generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
Create four generators with seed `0`, and generate another batch of images, all of which should look like the first image from the previous round!
```python
-images = pipe(prompt, generator=generator).images
-make_image_grid(images, rows=2, cols=2)
+>>> images = pipe(prompt, generator=generator).images
+>>> images
```

diff --git a/docs/source/en/using-diffusers/schedulers.md b/docs/source/en/using-diffusers/schedulers.md
index 6b5d8da465d8..c791b47b7832 100644
--- a/docs/source/en/using-diffusers/schedulers.md
+++ b/docs/source/en/using-diffusers/schedulers.md
@@ -14,14 +14,14 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize
-a pipeline to one's use case. The best example of this is the [Schedulers](../api/schedulers/overview).
+Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize
+a pipeline to one's use case. The best example of this is the [Schedulers](../api/schedulers/overview.md).
-Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
+Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
schedulers define the whole denoising process, *i.e.*:
- How many denoising steps?
- Stochastic or deterministic?
-- What algorithm to use to find the denoised sample?
+- What algorithm to use to find the denoised sample?
They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**.
It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try out which works best.
@@ -63,7 +63,7 @@ pipeline.scheduler
```
PNDMScheduler {
"_class_name": "PNDMScheduler",
- "_diffusers_version": "0.21.4",
+ "_diffusers_version": "0.8.0.dev0",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
@@ -72,12 +72,11 @@ PNDMScheduler {
"set_alpha_to_one": false,
"skip_prk_steps": true,
"steps_offset": 1,
- "timestep_spacing": "leading",
"trained_betas": null
}
```
-We can see that the scheduler is of type [`PNDMScheduler`].
+We can see that the scheduler is of type [`PNDMScheduler`].
Cool, now let's compare the scheduler in its performance to other schedulers.
First we define a prompt on which we will test all the different schedulers:
@@ -102,7 +101,7 @@ image
## Changing the scheduler
-Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`~SchedulerMixin.compatibles`]
+Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`]
which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows.
```python
@@ -111,40 +110,27 @@ pipeline.scheduler.compatibles
**Output**:
```
-[diffusers.utils.dummy_torch_and_torchsde_objects.DPMSolverSDEScheduler,
- diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
- diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
+[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
diffusers.schedulers.scheduling_ddim.DDIMScheduler,
- diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
- diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
- diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
+ diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
diffusers.schedulers.scheduling_pndm.PNDMScheduler,
- diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
- diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
- diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler]
+ diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
+ diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler]
```
-Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions:
-
-- [`EulerDiscreteScheduler`],
-- [`LMSDiscreteScheduler`],
-- [`DDIMScheduler`],
-- [`DDPMScheduler`],
-- [`HeunDiscreteScheduler`],
-- [`DPMSolverMultistepScheduler`],
-- [`DEISMultistepScheduler`],
-- [`PNDMScheduler`],
-- [`EulerAncestralDiscreteScheduler`],
-- [`UniPCMultistepScheduler`],
-- [`KDPM2DiscreteScheduler`],
-- [`DPMSolverSinglestepScheduler`],
-- [`KDPM2AncestralDiscreteScheduler`].
-
-We will now compare the input prompt with all other schedulers. To change the scheduler of the pipeline you can make use of the
-convenient [`~ConfigMixin.config`] property in combination with the [`~ConfigMixin.from_config`] function.
+Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions:
+
+- [`LMSDiscreteScheduler`],
+- [`DDIMScheduler`],
+- [`DPMSolverMultistepScheduler`],
+- [`EulerDiscreteScheduler`],
+- [`PNDMScheduler`],
+- [`DDPMScheduler`],
+- [`EulerAncestralDiscreteScheduler`].
+
+We will now compare the input prompt with all other schedulers. To change the scheduler of the pipeline you can make use of the
+convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] function.
```python
pipeline.scheduler.config
@@ -153,7 +139,7 @@ pipeline.scheduler.config
returns a dictionary of the configuration of the scheduler:
**Output**:
-```py
+```
FrozenDict([('num_train_timesteps', 1000),
('beta_start', 0.00085),
('beta_end', 0.012),
@@ -161,17 +147,14 @@ FrozenDict([('num_train_timesteps', 1000),
('trained_betas', None),
('skip_prk_steps', True),
('set_alpha_to_one', False),
- ('prediction_type', 'epsilon'),
- ('timestep_spacing', 'leading'),
('steps_offset', 1),
- ('_use_default_values', ['timestep_spacing', 'prediction_type']),
('_class_name', 'PNDMScheduler'),
- ('_diffusers_version', '0.21.4'),
+ ('_diffusers_version', '0.8.0.dev0'),
('clip_sample', False)])
```
This configuration can then be used to instantiate a scheduler
-of a different class that is compatible with the pipeline. Here,
+of a different class that is compatible with the pipeline. Here,
we change the scheduler to the [`DDIMScheduler`].
```python
@@ -198,8 +181,8 @@ If you are a JAX/Flax user, please check [this section](#changing-the-scheduler-
## Compare schedulers
-So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`].
-A number of better schedulers have been released that can be run with much fewer steps; let's compare them here:
+So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`].
+A number of better schedulers have been released that can be run with much fewer steps; let's compare them here:
[`LMSDiscreteScheduler`] usually leads to better results:
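A sketch of swapping it in, reusing the `pipeline`, `prompt`, and `torch` import from the earlier cells (the seed is only illustrative):

```py
from diffusers import LMSDiscreteScheduler

# instantiate the new scheduler from the current scheduler's configuration
pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
generator = torch.Generator(device="cuda").manual_seed(8)
image = pipeline(prompt, generator=generator).images[0]
image
```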
@@ -258,7 +241,8 @@ image
-[`DPMSolverMultistepScheduler`] gives a reasonable speed/quality trade-off and can be run with as little as 20 steps.
+At the time of writing this doc, [`DPMSolverMultistepScheduler`] gives arguably the best speed/quality trade-off and can be run with as few
+as 20 steps.
```python
from diffusers import DPMSolverMultistepScheduler
@@ -276,12 +260,12 @@ image
-As you can see, most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different
+As you can see, most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different
schedulers to compare results.
## Changing the Scheduler in Flax
-If you are a JAX/Flax user, you can also change the default pipeline scheduler. This is a complete example of how to run inference using the Flax Stable Diffusion pipeline and the super-fast [DPM-Solver++ scheduler](../api/schedulers/multistep_dpm_solver):
+If you are a JAX/Flax user, you can also change the default pipeline scheduler. This is a complete example of how to run inference using the Flax Stable Diffusion pipeline and the super-fast [DPM-Solver++ scheduler](../api/schedulers/multistep_dpm_solver):
```Python
import jax
diff --git a/docs/source/en/using-diffusers/sdxl.md b/docs/source/en/using-diffusers/sdxl.md
index 25b581fc6f6f..36286ecad863 100644
--- a/docs/source/en/using-diffusers/sdxl.md
+++ b/docs/source/en/using-diffusers/sdxl.md
@@ -1,15 +1,3 @@
-
-
# Stable Diffusion XL
[[open-in-colab]]
@@ -26,7 +14,7 @@ Before you begin, make sure you have the following libraries installed:
```py
# uncomment to install the necessary libraries in Colab
-#!pip install -q diffusers transformers accelerate omegaconf invisible-watermark>=0.2.0
+#!pip install diffusers transformers accelerate safetensors omegaconf invisible-watermark>=0.2.0
```
@@ -84,8 +72,7 @@ pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
).to("cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = pipeline_text2image(prompt=prompt).images[0]
-image
+image = pipeline_text2image(prompt=prompt).images[0]
```
@@ -97,17 +84,16 @@ image
For image-to-image, SDXL works especially well with image sizes between 768x768 and 1024x1024. Pass an initial image, and a text prompt to condition the image with:
```py
-from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import load_image, make_image_grid
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import load_image
# use from_pipe to avoid consuming additional memory when loading a checkpoint
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to("cuda")
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-img2img.png"
-url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
-init_image = load_image(url)
+init_image = load_image(url).convert("RGB")
prompt = "a dog catching a frisbee in the jungle"
image = pipeline(prompt, image=init_image, strength=0.8, guidance_scale=10.5).images[0]
-make_image_grid([init_image, image], rows=1, cols=2)
```
@@ -120,7 +106,7 @@ For inpainting, you'll need the original image and a mask of what you want to re
```py
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image, make_image_grid
+from diffusers.utils import load_image
# use from_pipe to avoid consuming additional memory when loading a checkpoint
pipeline = AutoPipelineForInpainting.from_pipe(pipeline_text2image).to("cuda")
@@ -128,12 +114,11 @@ pipeline = AutoPipelineForInpainting.from_pipe(pipeline_text2image).to("cuda")
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png"
-init_image = load_image(img_url)
-mask_image = load_image(mask_url)
+init_image = load_image(img_url).convert("RGB")
+mask_image = load_image(mask_url).convert("RGB")
prompt = "A deep sea diver floating"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, guidance_scale=12.5).images[0]
-make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -144,12 +129,12 @@ make_image_grid([init_image, mask_image, image], rows=1, cols=3)
SDXL includes a [refiner model](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0) specialized in denoising low-noise stage images to generate higher-quality images from the base model. There are two ways to use the refiner:
-1. use the base and refiner models together to produce a refined image
-2. use the base model to produce an image, and subsequently use the refiner model to add more details to the image (this is how SDXL was originally trained)
+1. use the base and refiner model together to produce a refined image
+2. use the base model to produce an image, and subsequently use the refiner model to add more details to the image (this is how SDXL was originally trained)
### Base + refiner model
-When you use the base and refiner model together to generate an image, this is known as an [*ensemble of expert denoisers*](https://research.nvidia.com/labs/dir/eDiff-I/). The ensemble of expert denoisers approach requires fewer overall denoising steps versus passing the base model's output to the refiner model, so it should be significantly faster to run. However, you won't be able to inspect the base model's output because it still contains a large amount of noise.
+When you use the base and refiner model together to generate an image, this is known as an [*ensemble of expert denoisers*](https://research.nvidia.com/labs/dir/eDiff-I/). The ensemble of expert denoisers approach requires fewer overall denoising steps versus passing the base model's output to the refiner model, so it should be significantly faster to run. However, you won't be able to inspect the base model's output because it still contains a large amount of noise.
As an ensemble of expert denoisers, the base model serves as the expert during the high-noise diffusion stage and the refiner model serves as the expert during the low-noise diffusion stage. Load the base and refiner model:
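A sketch of the loading step referenced here, mirroring the base + refiner example that appears later in this diff:

```py
from diffusers import DiffusionPipeline
import torch

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")

refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,  # share the second text encoder and VAE with the base model
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")
```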
@@ -196,13 +181,12 @@ image = refiner(
denoising_start=0.8,
image=image,
).images[0]
-image
```

-
default base model
+
base model

@@ -214,8 +198,7 @@ The refiner model can also be used for inpainting in the [`StableDiffusionXLInpa
```py
from diffusers import StableDiffusionXLInpaintPipeline
-from diffusers.utils import load_image, make_image_grid
-import torch
+from diffusers.utils import load_image
+import torch
base = StableDiffusionXLInpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
@@ -223,8 +206,8 @@ base = StableDiffusionXLInpaintPipeline.from_pretrained(
refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
- text_encoder_2=base.text_encoder_2,
- vae=base.vae,
+    text_encoder_2=base.text_encoder_2,
+    vae=base.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
@@ -233,8 +216,8 @@ refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
-init_image = load_image(img_url)
-mask_image = load_image(mask_url)
+init_image = load_image(img_url).convert("RGB")
+mask_image = load_image(mask_url).convert("RGB")
prompt = "A majestic tiger sitting on a bench"
num_inference_steps = 75
@@ -255,7 +238,6 @@ image = refiner(
num_inference_steps=num_inference_steps,
denoising_start=high_noise_frac,
).images[0]
-make_image_grid([init_image, mask_image, image.resize((512, 512))], rows=1, cols=3)
```
This ensemble of expert denoisers method works well for all available schedulers!
@@ -276,8 +258,8 @@ base = DiffusionPipeline.from_pretrained(
refiner = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
- text_encoder_2=base.text_encoder_2,
- vae=base.vae,
+    text_encoder_2=base.text_encoder_2,
+    vae=base.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
@@ -309,7 +291,7 @@ image = refiner(prompt=prompt, image=image[None, :]).images[0]
-For inpainting, load the base and the refiner model in the [`StableDiffusionXLInpaintPipeline`], remove the `denoising_end` and `denoising_start` parameters, and choose a smaller number of inference steps for the refiner.
+For inpainting, load the base and refiner model in the [`StableDiffusionXLInpaintPipeline`], remove the `denoising_end` and `denoising_start` parameters, and choose a smaller number of inference steps for the refiner, as sketched below.
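A hedged sketch of that flow, reusing `base`, `refiner`, `init_image`, and `mask_image` from the ensemble example above (the step counts are only illustrative):

```py
prompt = "A majestic tiger sitting on a bench"

# run the base model to completion, then give the refiner a smaller number of steps
image = base(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=75).images[0]
image = refiner(prompt=prompt, image=image, mask_image=mask_image, num_inference_steps=30).images[0]
```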
## Micro-conditioning
@@ -349,7 +331,7 @@ image = pipe(

-
Images negatively conditioned on image resolutions of (128, 128), (256, 256), and (512, 512).
+
Images negative conditioned on image resolutions of (128, 128), (256, 256), and (512, 512).
### Crop conditioning
@@ -360,13 +342,13 @@ Images generated by previous Stable Diffusion models may sometimes appear to be
from diffusers import StableDiffusionXLPipeline
import torch
+
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = pipeline(prompt=prompt, crops_coords_top_left=(256, 0)).images[0]
-image
+image = pipeline(prompt=prompt, crops_coords_top_left=(256, 0)).images[0]
```
@@ -390,12 +372,11 @@ image = pipe(
negative_crops_coords_top_left=(0, 0),
negative_target_size=(1024, 1024),
).images[0]
-image
```
## Use a different prompt for each text-encoder
-SDXL uses two text-encoders, so it is possible to pass a different prompt to each text-encoder, which can [improve quality](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201). Pass your original prompt to `prompt` and the second prompt to `prompt_2` (use `negative_prompt` and `negative_prompt_2` if you're using negative prompts):
+SDXL uses two text-encoders, so it is possible to pass a different prompt to each text-encoder, which can [improve quality](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201). Pass your original prompt to `prompt` and the second prompt to `prompt_2` (use `negative_prompt` and `negative_prompt_2` if you're using negative prompts):
```py
from diffusers import StableDiffusionXLPipeline
@@ -410,14 +391,13 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# prompt_2 is passed to OpenCLIP-ViT/bigG-14
prompt_2 = "Van Gogh painting"
image = pipeline(prompt=prompt, prompt_2=prompt_2).images[0]
-image
```
-The dual text-encoders also support textual inversion embeddings that need to be loaded separately as explained in the [SDXL textual inversion](textual_inversion_inference#stable-diffusion-xl) section.
+The dual text-encoders also support textual inversion embeddings that need to be loaded separately as explained in the [SDXL textual inversion](textual_inversion_inference#stable-diffusion-xl) section.
## Optimizations
@@ -428,18 +408,18 @@ SDXL is a large model, and you may need to optimize memory to get it to run on y
```diff
- base.to("cuda")
- refiner.to("cuda")
-+ base.enable_model_cpu_offload()
-+ refiner.enable_model_cpu_offload()
++ base.enable_model_cpu_offload()
++ refiner.enable_model_cpu_offload()
```
-2. Use `torch.compile` for ~20% speed-up (you need `torch>=2.0`):
+2. Use `torch.compile` for ~20% speed-up (you need `torch>=2.0`):
```diff
+ base.unet = torch.compile(base.unet, mode="reduce-overhead", fullgraph=True)
+ refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
```
-3. Enable [xFormers](../optimization/xformers) to run SDXL if `torch<2.0`:
+3. Enable [xFormers](../optimization/xformers) to run SDXL if `torch<2.0`:
```diff
+ base.enable_xformers_memory_efficient_attention()
diff --git a/docs/source/en/using-diffusers/shap-e.md b/docs/source/en/using-diffusers/shap-e.md
index f0ce977584a5..b74a652582ec 100644
--- a/docs/source/en/using-diffusers/shap-e.md
+++ b/docs/source/en/using-diffusers/shap-e.md
@@ -1,22 +1,10 @@
-
-
# Shap-E
[[open-in-colab]]
Shap-E is a conditional model for generating 3D assets which could be used for video game development, interior design, and architecture. It is trained on a large dataset of 3D assets, and post-processed to render more views of each object and produce 16K instead of 4K point clouds. The Shap-E model is trained in two steps:
-1. an encoder accepts the point clouds and rendered views of a 3D asset and outputs the parameters of implicit functions that represent the asset
+1. an encoder accepts the point clouds and rendered views of a 3D asset and outputs the parameters of implicit functions that represent the asset
2. a diffusion model is trained on the latents produced by the encoder to generate either neural radiance fields (NeRFs) or a textured 3D mesh, making it easier to render and use the 3D asset in downstream applications
This guide will show you how to use Shap-E to start generating your own 3D assets!
@@ -25,7 +13,7 @@ Before you begin, make sure you have the following libraries installed:
```py
# uncomment to install the necessary libraries in Colab
-#!pip install -q diffusers transformers accelerate trimesh
+#!pip install diffusers transformers accelerate safetensors trimesh
```
## Text-to-3D
@@ -38,7 +26,7 @@ from diffusers import ShapEPipeline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16, variant="fp16")
+pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
pipe = pipe.to(device)
guidance_scale = 15.0
@@ -64,17 +52,17 @@ export_to_gif(images[1], "cake_3d.gif")

-
prompt = "A firecracker"
+
firecracker

-
prompt = "A birthday cupcake"
+
cupcake
## Image-to-3D
-To generate a 3D object from another image, use the [`ShapEImg2ImgPipeline`]. You can use an existing image or generate an entirely new one. Let's use the [Kandinsky 2.1](../api/pipelines/kandinsky) model to generate a new image.
+To generate a 3D object from another image, use the [`ShapEImg2ImgPipeline`]. You can use an existing image or generate an entirely new one. Let's use the [Kandinsky 2.1](../api/pipelines/kandinsky) model to generate a new image.
```py
from diffusers import DiffusionPipeline
@@ -99,7 +87,6 @@ Pass the cheeseburger to the [`ShapEImg2ImgPipeline`] to generate a 3D represent
```py
from PIL import Image
from diffusers import ShapEImg2ImgPipeline
from diffusers.utils import export_to_gif
pipe = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16, variant="fp16").to("cuda")
@@ -140,7 +127,7 @@ from diffusers import ShapEPipeline
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16, variant="fp16")
+pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
pipe = pipe.to(device)
guidance_scale = 15.0
@@ -161,7 +148,7 @@ You can optionally save the mesh output as an `obj` file with the [`~utils.expor
from diffusers.utils import export_to_ply
ply_path = export_to_ply(images[0], "3d_cake.ply")
-print(f"Saved to folder: {ply_path}")
+print(f"saved to folder: {ply_path}")
```
Then you can convert the `ply` file to a `glb` file with the trimesh library:
@@ -170,7 +157,7 @@ Then you can convert the `ply` file to a `glb` file with the trimesh library:
import trimesh
mesh = trimesh.load("3d_cake.ply")
-mesh_export = mesh.export("3d_cake.glb", file_type="glb")
+mesh.export("3d_cake.glb", file_type="glb")
```
By default, the mesh output is focused from the bottom viewpoint but you can change the default viewpoint by applying a rotation transform:
@@ -182,11 +169,11 @@ import numpy as np
mesh = trimesh.load("3d_cake.ply")
rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
mesh = mesh.apply_transform(rot)
-mesh_export = mesh.export("3d_cake.glb", file_type="glb")
+mesh.export("3d_cake.glb", file_type="glb")
```
Upload the mesh file to your dataset repository to visualize it with the Dataset viewer!

-
+
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md b/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
index 6f75ba2c3999..d62ce0bf91bf 100644
--- a/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
+++ b/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
@@ -1,15 +1,3 @@
-
-
# JAX/Flax
[[open-in-colab]]
@@ -38,20 +26,25 @@ device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
-assert (
-    "TPU" in device_type,
-    "Available device is not a TPU, please select TPU from Runtime > Change runtime type > Hardware accelerator"
-)
+assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
-# Found 8 JAX devices of type Cloud TPU.
+"Found 8 JAX devices of type Cloud TPU."
```
Great, now you can import the rest of the dependencies you'll need:
```python
+import numpy as np
import jax.numpy as jnp
+
+from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
+from PIL import Image
+from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
```
@@ -85,7 +78,7 @@ prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, por
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
-# (8, 77)
+"(8, 77)"
```
Model parameters and inputs have to be replicated across the 8 parallel devices. The parameters dictionary is replicated with [`flax.jax_utils.replicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.jax_utils.html#flax.jax_utils.replicate) which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using `shard`.
@@ -97,7 +90,7 @@ p_params = replicate(params)
# arrays
prompt_ids = shard(prompt_ids)
prompt_ids.shape
-# (8, 1, 77)
+"(8, 1, 77)"
```
This shape means each one of the 8 devices receives as an input a `jnp` array with shape `(1, 77)`, where `1` is the batch size per device. On TPUs with sufficient memory, you could have a batch size larger than `1` if you want to generate multiple images (per chip) at once.
@@ -122,7 +115,7 @@ To take advantage of JAX's optimized speed on a TPU, pass `jit=True` to the pipe
-You need to ensure all your inputs have the same shape in subsequent calls, otherwise JAX will need to recompile the code which is slower.
+You need to ensure all your inputs have the same shape in subsequent calls, otherwise JAX will need to recompile the code, which is slower.
@@ -132,18 +125,18 @@ The first inference run takes more time because it needs to compile the code, bu
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
-# CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
-# Wall time: 1min 29s
+"CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s"
+"Wall time: 1min 29s"
```
The returned array has shape `(8, 1, 512, 512, 3)` which should be reshaped to remove the second dimension and get 8 images of `512 × 512 × 3`. Then you can use the [`~utils.numpy_to_pil`] function to convert the arrays into images.
```python
-from diffusers.utils import make_image_grid
+from diffusers.utils import make_image_grid
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
-make_image_grid(images, rows=2, cols=4)
+make_image_grid(images, 2, 4)
```

@@ -176,6 +169,7 @@ make_image_grid(images, 2, 4)

+
## How does parallelization work?
The Flax pipeline in 🤗 Diffusers automatically compiles the model and runs it in parallel on all available devices. Let's take a closer look at how that process works.
@@ -196,7 +190,7 @@ p_generate = pmap(pipeline._generate)
After calling `pmap`, the prepared function `p_generate` will:
1. Make a copy of the underlying function, `pipeline._generate`, on each device.
-2. Send each device a different portion of the input arguments (this is why it's necessary to call the *shard* function). In this case, `prompt_ids` has shape `(8, 1, 77, 768)` so the array is split into 8 and each copy of `_generate` receives an input with shape `(1, 77, 768)`.
+2. Send each device a different portion of the input arguments (this is why it's necessary to call the *shard* function). In this case, `prompt_ids` has shape `(8, 1, 77, 768)` so the array is split into 8 and each copy of `_generate` receives an input with shape `(1, 77, 768)`.
The most important thing to pay attention to here is the batch size (1 in this example), and the input dimensions that make sense for your code. You don't have to change anything else to make the code work in parallel.
@@ -206,14 +200,13 @@ The first time you call the pipeline takes more time, but the calls afterward ar
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
-
-# CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
-# Wall time: 1min 15s
+"CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s"
+"Wall time: 1min 15s"
```
Check your image dimensions to see if they're correct:
```python
images.shape
-# (8, 1, 512, 512, 3)
-```
+"(8, 1, 512, 512, 3)"
+```
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/textual_inversion_inference.md b/docs/source/en/using-diffusers/textual_inversion_inference.md
index 084101c06ba3..0ca4ecc58d4e 100644
--- a/docs/source/en/using-diffusers/textual_inversion_inference.md
+++ b/docs/source/en/using-diffusers/textual_inversion_inference.md
@@ -1,29 +1,31 @@
-
-
# Textual inversion
[[open-in-colab]]
The [`StableDiffusionPipeline`] supports textual inversion, a technique that enables a model like Stable Diffusion to learn a new concept from just a few sample images. This gives you more control over the generated images and allows you to tailor the model towards specific concepts. You can get started quickly with a collection of community created concepts in the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer).
-This guide will show you how to run inference with textual inversion using a pre-learned concept from the Stable Diffusion Conceptualizer. If you're interested in teaching a model new concepts with textual inversion, take a look at the [Textual Inversion](../training/text_inversion) training guide.
+This guide will show you how to run inference with textual inversion using a pre-learned concept from the Stable Diffusion Conceptualizer. If you're interested in teaching a model new concepts with textual inversion, take a look at the [Textual Inversion](../training/text_inversion) training guide.
+
+Login to your Hugging Face account:
+
+```py
+from huggingface_hub import notebook_login
+
+notebook_login()
+```
Import the necessary libraries:
```py
+import os
import torch
+
+import PIL
+from PIL import Image
+
from diffusers import StableDiffusionPipeline
from diffusers.utils import make_image_grid
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
```
## Stable Diffusion 1 and 2
@@ -50,7 +52,7 @@ Create a prompt with the pre-learned concept by using the special placeholder to
```py
prompt = "a grafitti in a favela wall with a
on it"
-num_samples_per_row = 2
+num_samples = 2
num_rows = 2
```
@@ -59,10 +61,10 @@ Then run the pipeline (feel free to adjust the parameters like `num_inference_st
```py
all_images = []
for _ in range(num_rows):
- images = pipeline(prompt, num_images_per_prompt=num_samples_per_row, num_inference_steps=50, guidance_scale=7.5).images
+ images = pipe(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images
all_images.extend(images)
-grid = make_image_grid(all_images, num_rows, num_samples_per_row)
+grid = make_image_grid(all_images, num_samples, num_rows)
grid
```
@@ -70,6 +72,7 @@ grid
+
## Stable Diffusion XL
Stable Diffusion XL (SDXL) can also use textual inversion vectors for inference. In contrast to Stable Diffusion 1 and 2, SDXL has two text encoders so you'll need two textual inversion embeddings - one for each text encoder model.
@@ -94,9 +97,9 @@ state_dict
[ 0.0475, -0.0508, -0.0145, ..., 0.0070, -0.0089, -0.0163]],
```
-There are two tensors, `"clip_g"` and `"clip_l"`.
-`"clip_g"` corresponds to the bigger text encoder in SDXL and refers to
-`pipe.text_encoder_2` and `"clip_l"` refers to `pipe.text_encoder`.
+There are two tensors, `"clip_g"` and `"clip_l"`.
+`"clip_g"` corresponds to the bigger text encoder in SDXL and refers to
+`pipe.text_encoder_2`, and `"clip_l"` refers to `pipe.text_encoder`.
Now you can load each tensor separately by passing them along with the correct text encoder and tokenizer
to [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`]:
@@ -114,5 +117,4 @@ pipe.load_textual_inversion(state_dict["clip_l"], token="unaestheticXLv31", text
# the embedding should be used as a negative embedding, so we pass it as a negative prompt
generator = torch.Generator().manual_seed(33)
image = pipe("a woman standing in front of a mountain", negative_prompt="unaestheticXLv31", generator=generator).images[0]
-image
```
diff --git a/docs/source/en/using-diffusers/unconditional_image_generation.md b/docs/source/en/using-diffusers/unconditional_image_generation.md
index 6c55c4edec08..3893f7cce276 100644
--- a/docs/source/en/using-diffusers/unconditional_image_generation.md
+++ b/docs/source/en/using-diffusers/unconditional_image_generation.md
@@ -14,42 +14,56 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-Unconditional image generation generates images that look like a random sample from the training data the model was trained on because the denoising process is not guided by any additional context like text or image.
+Unconditional image generation is a relatively straightforward task. The model only generates images resembling its training data, without any additional context like text or an image.
-To get started, use the [`DiffusionPipeline`] to load the [anton-l/ddpm-butterflies-128](https://huggingface.co/anton-l/ddpm-butterflies-128) checkpoint to generate images of butterflies. The [`DiffusionPipeline`] downloads and caches all the model components required to generate an image.
+The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference.
-```py
-from diffusers import DiffusionPipeline
-
-generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128").to("cuda")
-image = generator().images[0]
-image
-```
+Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
+You can use any of the 🧨 Diffusers [checkpoints](https://huggingface.co/models?library=diffusers&sort=downloads) from the Hub (the checkpoint you'll use generates images of butterflies).
-Want to generate images of something else? Take a look at the training [guide](../training/unconditional_training) to learn how to train a model to generate your own images.
+💡 Want to train your own unconditional image generation model? Take a look at the training [guide](../training/unconditional_training) to learn how to train a model to generate your own images.
-The output image is a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object that can be saved:
+In this guide, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239):
+
+```python
+>>> from diffusers import DiffusionPipeline
+
+>>> generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
+```
+
+The [`DiffusionPipeline`] downloads and caches all of the modeling and scheduling components.
+Because the denoising process runs the model many times in sequence, we strongly recommend running it on a GPU.
+You can move the generator object to a GPU, just like you would in PyTorch:
-```py
-image.save("generated_image.png")
+```python
+>>> generator.to("cuda")
```
-You can also try experimenting with the `num_inference_steps` parameter, which controls the number of denoising steps. More denoising steps typically produce higher quality images, but it'll take longer to generate. Feel free to play around with this parameter to see how it affects the image quality.
+Now you can use the `generator` to generate an image:
-```py
-image = generator(num_inference_steps=100).images[0]
-image
+```python
+>>> image = generator().images[0]
```
-Try out the Space below to generate an image of a butterfly!
+By default, the output is wrapped in a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
+
+You can save the image by calling:
+
+```python
+>>> image.save("generated_image.png")
+```
+
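+You can also pass `num_inference_steps` to the pipeline call; more denoising steps typically produce a higher quality image, at the cost of slower generation. For example:
+
+```python
+>>> image = generator(num_inference_steps=100).images[0]
+>>> image
+```
+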
+Try out the Spaces below, and feel free to play around with the inference steps parameter to see how it affects the image quality!
+
+
diff --git a/docs/source/en/using-diffusers/using_safetensors.md b/docs/source/en/using-diffusers/using_safetensors.md
index 3e89e7eed9a0..2f47eb08cb83 100644
--- a/docs/source/en/using-diffusers/using_safetensors.md
+++ b/docs/source/en/using-diffusers/using_safetensors.md
@@ -1,15 +1,3 @@
-
-
# Load safetensors
[[open-in-colab]]
@@ -67,11 +55,11 @@ There are several reasons for using safetensors:
The time it takes to load the entire pipeline:
```py
- from diffusers import StableDiffusionPipeline
+ from diffusers import StableDiffusionPipeline
- pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", use_safetensors=True)
- "Loaded in safetensors 0:00:02.033658"
- "Loaded in PyTorch 0:00:02.663379"
+ pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", use_safetensors=True)
+ "Loaded in safetensors 0:00:02.033658"
+ "Loaded in PyTorch 0:00:02.663379"
```
But the actual time it takes to load 500MB of the model weights is only:
diff --git a/docs/source/en/using-diffusers/weighted_prompts.md b/docs/source/en/using-diffusers/weighted_prompts.md
index 947d18b86ec8..ede2c7f35169 100644
--- a/docs/source/en/using-diffusers/weighted_prompts.md
+++ b/docs/source/en/using-diffusers/weighted_prompts.md
@@ -41,7 +41,6 @@ import torch
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_safetensors=True)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-pipe.to("cuda")
prompt = "a red cat playing with a ball"
@@ -142,7 +141,7 @@ image
## Conjunction
A conjunction diffuses each prompt independently and concatenates their results by their weighted sum. Add `.and()` to the end of a list of prompts to create a conjunction:
-
+
```py
prompt_embeds = compel_proc('["a red cat", "playing with a", "ball"].and()')
generator = torch.Generator(device="cuda").manual_seed(55)
@@ -166,9 +165,7 @@ import torch
from diffusers import StableDiffusionPipeline
from compel import Compel, DiffusersTextualInversionManager
-pipe = StableDiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16,
- use_safetensors=True, variant="fp16").to("cuda")
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, variant="fp16").to("cuda")
pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
```
@@ -176,7 +173,7 @@ Compel provides a `DiffusersTextualInversionManager` class to simplify prompt we
```py
textual_inversion_manager = DiffusersTextualInversionManager(pipe)
-compel_proc = Compel(
+compel = Compel(
tokenizer=pipe.tokenizer,
text_encoder=pipe.text_encoder,
textual_inversion_manager=textual_inversion_manager)
@@ -228,8 +225,6 @@ Stable Diffusion XL (SDXL) has two tokenizers and text encoders so it's usage is
```py
from compel import Compel, ReturnedEmbeddingsType
from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
-import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
@@ -256,7 +251,6 @@ conditioning, pooled = compel(prompt)
# generate image
generator = [torch.Generator().manual_seed(33) for _ in range(len(prompt))]
images = pipeline(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, generator=generator, num_inference_steps=30).images
-make_image_grid(images, rows=1, cols=2)
```
@@ -268,4 +262,4 @@ make_image_grid(images, rows=1, cols=2)
"a red cat playing with a (ball)0.6"
-
+
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/write_own_pipeline.md b/docs/source/en/using-diffusers/write_own_pipeline.md
index 4ca3fe33223b..a9243a7b9adc 100644
--- a/docs/source/en/using-diffusers/write_own_pipeline.md
+++ b/docs/source/en/using-diffusers/write_own_pipeline.md
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-🧨 Diffusers is designed to be a user-friendly and flexible toolbox for building diffusion systems tailored to your use-case. At the core of the toolbox are models and schedulers. While the [`DiffusionPipeline`] bundles these components together for convenience, you can also unbundle the pipeline and use the models and schedulers separately to create new diffusion systems.
+🧨 Diffusers is designed to be a user-friendly and flexible toolbox for building diffusion systems tailored to your use-case. At the core of the toolbox are models and schedulers. While the [`DiffusionPipeline`] bundles these components together for convenience, you can also unbundle the pipeline and use the models and schedulers separately to create new diffusion systems.
In this tutorial, you'll learn how to use models and schedulers to assemble a diffusion system for inference, starting with a basic pipeline and then progressing to the Stable Diffusion pipeline.
@@ -36,7 +36,7 @@ A pipeline is a quick and easy way to run a model for inference, requiring no mo
That was super easy, but how did the pipeline do that? Let's breakdown the pipeline and take a look at what's happening under the hood.
-In the example above, the pipeline contains a [`UNet2DModel`] model and a [`DDPMScheduler`]. The pipeline denoises an image by taking random noise the size of the desired output and passing it through the model several times. At each timestep, the model predicts the *noise residual* and the scheduler uses it to predict a less noisy image. The pipeline repeats this process until it reaches the end of the specified number of inference steps.
+In the example above, the pipeline contains a [`UNet2DModel`] model and a [`DDPMScheduler`]. The pipeline denoises an image by taking random noise the size of the desired output and passing it through the model several times. At each timestep, the model predicts the *noise residual* and the scheduler uses it to predict a less noisy image. The pipeline repeats this process until it reaches the end of the specified number of inference steps.
To recreate the pipeline with the model and scheduler separately, let's write our own denoising process.
@@ -71,7 +71,7 @@ tensor([980, 960, 940, 920, 900, 880, 860, 840, 820, 800, 780, 760, 740, 720,
>>> import torch
>>> sample_size = model.config.sample_size
->>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
+>>> noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
```
5. Now write a loop to iterate over the timesteps. At each timestep, the model does a [`UNet2DModel.forward`] pass and returns the noisy residual. The scheduler's [`~DDPMScheduler.step`] method takes the noisy residual, timestep, and input and it predicts the image at the previous timestep. This output becomes the next input to the model in the denoising loop, and it'll repeat until it reaches the end of the `timesteps` array.
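
A minimal sketch of that loop, using the `model`, `scheduler`, and `noise` objects created in the previous steps:

```py
>>> input = noise

>>> for t in scheduler.timesteps:
...     with torch.no_grad():
...         noisy_residual = model(input, t).sample
...     previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
...     input = previous_noisy_sample  # the denoised sample becomes the next model input
```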
@@ -153,7 +153,7 @@ To speed up inference, move the models to a GPU since, unlike the scheduler, the
### Create text embeddings
-The next step is to tokenize the text to generate embeddings. The text is used to condition the UNet model and steer the diffusion process towards something that resembles the input prompt.
+The next step is to tokenize the text to generate embeddings. The text is used to condition the UNet model and steer the diffusion process towards something that resembles the input prompt.
@@ -169,7 +169,7 @@ Feel free to choose any prompt you like if you want to generate something else!
>>> width = 512 # default width of Stable Diffusion
>>> num_inference_steps = 25 # Number of denoising steps
>>> guidance_scale = 7.5 # Scale for classifier-free guidance
->>> generator = torch.manual_seed(0) # Seed generator to create the initial latent noise
+>>> generator = torch.manual_seed(0)  # Seed generator to create the initial latent noise
>>> batch_size = len(prompt)
```
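
Tokenizing the prompt and encoding it into embeddings, as described above, looks roughly like this (a minimal sketch, assuming the `tokenizer` and `text_encoder` loaded earlier in the guide):

```py
>>> text_input = tokenizer(
...     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
... )

>>> with torch.no_grad():
...     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
```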
@@ -216,8 +216,8 @@ Next, generate some initial random noise as a starting point for the diffusion p
>>> latents = torch.randn(
... (batch_size, unet.config.in_channels, height // 8, width // 8),
... generator=generator,
-... device=torch_device,
... )
+>>> latents = latents.to(torch_device)
```
### Denoise the image
@@ -284,11 +284,11 @@ Lastly, convert the image to a `PIL.Image` to see your generated image!
## Next steps
-From basic to complex pipelines, you've seen that all you really need to write your own diffusion system is a denoising loop. The loop should set the scheduler's timesteps, iterate over them, and alternate between calling the UNet model to predict the noise residual and passing it to the scheduler to compute the previous noisy sample.
+From basic to complex pipelines, you've seen that all you really need to write your own diffusion system is a denoising loop. The loop should set the scheduler's timesteps, iterate over them, and alternate between calling the UNet model to predict the noise residual and passing it to the scheduler to compute the previous noisy sample.
This is really what 🧨 Diffusers is designed for: to make it intuitive and easy to write your own diffusion system using models and schedulers.
For your next steps, feel free to:
-* Learn how to [build and contribute a pipeline](../using-diffusers/contribute_pipeline) to 🧨 Diffusers. We can't wait and see what you'll come up with!
+* Learn how to [build and contribute a pipeline](contribute_pipeline) to 🧨 Diffusers. We can't wait to see what you'll come up with!
* Explore [existing pipelines](../api/pipelines/overview) in the library, and see if you can deconstruct and build a pipeline from scratch using the models and schedulers separately.
diff --git a/docs/source/ko/optimization/fp16.md b/docs/source/ko/optimization/fp16.md
index 0f2c487a75ce..30197305540c 100644
--- a/docs/source/ko/optimization/fp16.md
+++ b/docs/source/ko/optimization/fp16.md
@@ -273,9 +273,9 @@ unet_runs_per_experiment = 50
# 입력 불러오기
def generate_inputs():
- sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
- timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
- encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
+ sample = torch.randn(2, 4, 64, 64).half().cuda()
+ timestep = torch.rand(1).half().cuda() * 999
+ encoder_hidden_states = torch.randn(2, 77, 768).half().cuda()
return sample, timestep, encoder_hidden_states
diff --git a/docs/source/ko/tutorials/basic_training.md b/docs/source/ko/tutorials/basic_training.md
index 1cc82d2b8ce6..e18c82c4fd4b 100644
--- a/docs/source/ko/tutorials/basic_training.md
+++ b/docs/source/ko/tutorials/basic_training.md
@@ -283,27 +283,36 @@ TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽
```py
>>> from accelerate import Accelerator
->>> from huggingface_hub import create_repo, upload_folder
+>>> from huggingface_hub import HfFolder, Repository, whoami
>>> from tqdm.auto import tqdm
>>> from pathlib import Path
>>> import os
+>>> def get_full_repo_name(model_id: str, organization: str = None, token: str = None):
+... if token is None:
+... token = HfFolder.get_token()
+... if organization is None:
+... username = whoami(token)["name"]
+... return f"{username}/{model_id}"
+... else:
+... return f"{organization}/{model_id}"
+
+
>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
-... # Initialize accelerator and tensorboard logging
+... # accelerator와 tensorboard 로깅 초기화
... accelerator = Accelerator(
... mixed_precision=config.mixed_precision,
... gradient_accumulation_steps=config.gradient_accumulation_steps,
... log_with="tensorboard",
-... project_dir=os.path.join(config.output_dir, "logs"),
+... logging_dir=os.path.join(config.output_dir, "logs"),
... )
... if accelerator.is_main_process:
-... if config.output_dir is not None:
-... os.makedirs(config.output_dir, exist_ok=True)
... if config.push_to_hub:
-... repo_id = create_repo(
-... repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
-... ).repo_id
+... repo_name = get_full_repo_name(Path(config.output_dir).name)
+... repo = Repository(config.output_dir, clone_from=repo_name)
+... elif config.output_dir is not None:
+... os.makedirs(config.output_dir, exist_ok=True)
... accelerator.init_trackers("train_example")
... # 모든 것이 준비되었습니다.
@@ -322,14 +331,13 @@ TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽
... for step, batch in enumerate(train_dataloader):
... clean_images = batch["images"]
... # 이미지에 더할 노이즈를 샘플링합니다.
-... noise = torch.randn(clean_images.shape, device=clean_images.device)
+... noise = torch.randn(clean_images.shape).to(clean_images.device)
... bs = clean_images.shape[0]
... # 각 이미지를 위한 랜덤한 타임스텝(timestep)을 샘플링합니다.
... timesteps = torch.randint(
-... 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
-... dtype=torch.int64
-... )
+... 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
+... ).long()
... # 각 타임스텝의 노이즈 크기에 따라 깨끗한 이미지에 노이즈를 추가합니다.
... # (이는 forward diffusion 과정입니다.)
@@ -361,12 +369,7 @@ TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽
... if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
... if config.push_to_hub:
-... upload_folder(
-... repo_id=repo_id,
-... folder_path=config.output_dir,
-... commit_message=f"Epoch {epoch}",
-... ignore_patterns=["step_*", "epoch_*"],
-... )
+... repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True)
... else:
... pipeline.save_pretrained(config.output_dir)
```
diff --git a/docs/source/ko/using-diffusers/write_own_pipeline.md b/docs/source/ko/using-diffusers/write_own_pipeline.md
index 787c8113bf0d..a6469644566c 100644
--- a/docs/source/ko/using-diffusers/write_own_pipeline.md
+++ b/docs/source/ko/using-diffusers/write_own_pipeline.md
@@ -71,7 +71,7 @@ specific language governing permissions and limitations under the License.
>>> import torch
>>> sample_size = model.config.sample_size
- >>> noise = torch.randn((1, 3, sample_size, sample_size), device="cuda")
+ >>> noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
```
5. 이제 timestep을 반복하는 루프를 작성합니다. 각 timestep에서 모델은 [`UNet2DModel.forward`]를 통해 noisy residual을 반환합니다. 스케줄러의 [`~DDPMScheduler.step`] 메서드는 noisy residual, timestep, 그리고 입력을 받아 이전 timestep에서 이미지를 예측합니다. 이 출력은 노이즈 제거 루프의 모델에 대한 다음 입력이 되며, `timesteps` 배열의 끝에 도달할 때까지 반복됩니다.
@@ -212,8 +212,8 @@ Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent di
>>> latents = torch.randn(
... (batch_size, unet.in_channels, height // 8, width // 8),
... generator=generator,
-... device=torch_device,
... )
+>>> latents = latents.to(torch_device)
```
### 이미지 노이즈 제거
diff --git a/docs/source/zh/stable_diffusion.md b/docs/source/zh/stable_diffusion.md
index e28607b09032..8a740a8b44eb 100644
--- a/docs/source/zh/stable_diffusion.md
+++ b/docs/source/zh/stable_diffusion.md
@@ -1,264 +1,264 @@
-
-
-# 有效且高效的扩散
-
-[[open-in-colab]]
-
-让 [`DiffusionPipeline`] 生成特定风格或包含你所想要的内容的图像可能会有些棘手。 通常情况下,你需要多次运行 [`DiffusionPipeline`] 才能得到满意的图像。但是从无到有生成图像是一个计算密集的过程,特别是如果你要一遍又一遍地进行推理运算。
-
-这就是为什么从pipeline中获得最高的 *computational* (speed) 和 *memory* (GPU RAM) 非常重要 ,以减少推理周期之间的时间,从而使迭代速度更快。
-
-
-本教程将指导您如何通过 [`DiffusionPipeline`] 更快、更好地生成图像。
-
-
-首先,加载 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) 模型:
-
-```python
-from diffusers import DiffusionPipeline
-
-model_id = "runwayml/stable-diffusion-v1-5"
-pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
-```
-
-本教程将使用的提示词是 [`portrait photo of a old warrior chief`] ,但是你可以随心所欲的想象和构造自己的提示词:
-
-```python
-prompt = "portrait photo of a old warrior chief"
-```
-
-## 速度
-
-
-
-💡 如果你没有 GPU, 你可以从像 [Colab](https://colab.research.google.com/) 这样的 GPU 提供商获取免费的 GPU !
-
-
-
-加速推理的最简单方法之一是将 pipeline 放在 GPU 上 ,就像使用任何 PyTorch 模块一样:
-
-```python
-pipeline = pipeline.to("cuda")
-```
-
-为了确保您可以使用相同的图像并对其进行改进,使用 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) 方法,然后设置一个随机数种子 以确保其 [复现性](./using-diffusers/reproducibility):
-
-```python
-import torch
-
-generator = torch.Generator("cuda").manual_seed(0)
-```
-
-现在,你可以生成一个图像:
-
-```python
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-

-
-
-在 T4 GPU 上,这个过程大概要30秒(如果你的 GPU 比 T4 好,可能会更快)。在默认情况下,[`DiffusionPipeline`] 使用完整的 `float32` 精度进行 50 步推理。你可以通过降低精度(如 `float16` )或者减少推理步数来加速整个过程
-
-
-让我们把模型的精度降低至 `float16` ,然后生成一张图像:
-
-```python
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
-pipeline = pipeline.to("cuda")
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-

-
-
-这一次,生成图像只花了约 11 秒,比之前快了近 3 倍!
-
-
-
-💡 我们强烈建议把 pipeline 精度降低至 `float16` , 到目前为止, 我们很少看到输出质量有任何下降。
-
-
-
-另一个选择是减少推理步数。 你可以选择一个更高效的调度器 (*scheduler*) 可以减少推理步数同时保证输出质量。您可以在 [DiffusionPipeline] 中通过调用compatibles方法找到与当前模型兼容的调度器 (*scheduler*)。
-
-```python
-pipeline.scheduler.compatibles
-[
- diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
- diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
- diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
- diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
- diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
- diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
- diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
- diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
- diffusers.schedulers.scheduling_pndm.PNDMScheduler,
- diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
- diffusers.schedulers.scheduling_ddim.DDIMScheduler,
-]
-```
-
-Stable Diffusion 模型默认使用的是 [`PNDMScheduler`] ,通常要大概50步推理, 但是像 [`DPMSolverMultistepScheduler`] 这样更高效的调度器只要大概 20 或 25 步推理. 使用 [`ConfigMixin.from_config`] 方法加载新的调度器:
-
-```python
-from diffusers import DPMSolverMultistepScheduler
-
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-```
-
-现在将 `num_inference_steps` 设置为 20:
-
-```python
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
-image
-```
-
-
-

-
-
-太棒了!你成功把推理时间缩短到 4 秒!⚡️
-
-## 内存
-
-改善 pipeline 性能的另一个关键是减少内存的使用量,这间接意味着速度更快,因为你经常试图最大化每秒生成的图像数量。要想知道你一次可以生成多少张图片,最简单的方法是尝试不同的batch size,直到出现`OutOfMemoryError` (OOM)。
-
-创建一个函数,为每一批要生成的图像分配提示词和 `Generators` 。请务必为每个`Generator` 分配一个种子,以便于复现良好的结果。
-
-
-```python
-def get_inputs(batch_size=1):
- generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
- prompts = batch_size * [prompt]
- num_inference_steps = 20
-
- return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
-```
-
-设置 `batch_size=4` ,然后看一看我们消耗了多少内存:
-
-```python
-from diffusers.utils import make_image_grid
-
-images = pipeline(**get_inputs(batch_size=4)).images
-make_image_grid(images, 2, 2)
-```
-
-除非你有一个更大内存的GPU, 否则上述代码会返回 `OOM` 错误! 大部分内存被 cross-attention 层使用。按顺序运行可以节省大量内存,而不是在批处理中进行。你可以为 pipeline 配置 [`~DiffusionPipeline.enable_attention_slicing`] 函数:
-
-```python
-pipeline.enable_attention_slicing()
-```
-
-现在尝试把 `batch_size` 增加到 8!
-
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-

-
-
-以前你不能一批生成 4 张图片,而现在你可以在一张图片里面生成八张图片而只需要大概3.5秒!这可能是 T4 GPU 在不牺牲质量的情况运行速度最快的一种方法。
-
-## 质量
-
-在最后两节中, 你要学习如何通过 `fp16` 来优化 pipeline 的速度, 通过使用性能更高的调度器来减少推理步数, 使用注意力切片(*enabling attention slicing*)方法来节省内存。现在,你将关注的是如何提高图像的质量。
-
-### 更好的 checkpoints
-
-有个显而易见的方法是使用更好的 checkpoints。 Stable Diffusion 模型是一个很好的起点, 自正式发布以来,还发布了几个改进版本。然而, 使用更新的版本并不意味着你会得到更好的结果。你仍然需要尝试不同的 checkpoints ,并做一些研究 (例如使用 [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) 来获得更好的结果。
-
-随着该领域的发展, 有越来越多经过微调的高质量的 checkpoints 用来生成不一样的风格. 在 [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) 和 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) 寻找你感兴趣的一种!
-
-### 更好的 pipeline 组件
-
-也可以尝试用新版本替换当前 pipeline 组件。让我们加载最新的 [autodecoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) 从 Stability AI 加载到 pipeline, 并生成一些图像:
-
-```python
-from diffusers import AutoencoderKL
-
-vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
-pipeline.vae = vae
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-

-
-
-### 更好的提示词工程
-
-用于生成图像的文本非常重要, 因此被称为 *提示词工程*。 在设计提示词工程应注意如下事项:
-
-- 我想生成的图像或类似图像如何存储在互联网上?
-- 我可以提供哪些额外的细节来引导模型朝着我想要的风格生成?
-
-考虑到这一点,让我们改进提示词,以包含颜色和更高质量的细节:
-
-```python
-prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
-prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
-```
-
-使用新的提示词生成一批图像:
-
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-

-
-
-非常的令人印象深刻! Let's tweak the second image - 把 `Generator` 的种子设置为 `1` - 添加一些关于年龄的主题文本:
-
-```python
-prompts = [
- "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
-]
-
-generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
-images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
-make_image_grid(images, 2, 2)
-```
-
-
-

-
-
-## 最后
-
-在本教程中, 您学习了如何优化[`DiffusionPipeline`]以提高计算和内存效率,以及提高生成输出的质量. 如果你有兴趣让你的 pipeline 更快, 可以看一看以下资源:
-
-- 学习 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) 可以让推理速度提高 5 - 300% . 在 A100 GPU 上, 推理速度可以提高 50% !
-- 如果你没法用 PyTorch 2, 我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制(*memory-efficient attention mechanism*)与PyTorch 1.13.1配合使用,速度更快,内存消耗更少。
-- 其他的优化技术, 如:模型卸载(*model offloading*), 包含在 [这份指南](./optimization/fp16).
+
+
+# 有效且高效的扩散
+
+[[open-in-colab]]
+
+让 [`DiffusionPipeline`] 生成特定风格或包含你所想要的内容的图像可能会有些棘手。 通常情况下,你需要多次运行 [`DiffusionPipeline`] 才能得到满意的图像。但是从无到有生成图像是一个计算密集的过程,特别是如果你要一遍又一遍地进行推理运算。
+
+这就是为什么从pipeline中获得最高的 *computational* (speed) 和 *memory* (GPU RAM) 非常重要 ,以减少推理周期之间的时间,从而使迭代速度更快。
+
+
+本教程将指导您如何通过 [`DiffusionPipeline`] 更快、更好地生成图像。
+
+
+首先,加载 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) 模型:
+
+```python
+from diffusers import DiffusionPipeline
+
+model_id = "runwayml/stable-diffusion-v1-5"
+pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
+```
+
+本教程将使用的提示词是 [`portrait photo of a old warrior chief`] ,但是你可以随心所欲的想象和构造自己的提示词:
+
+```python
+prompt = "portrait photo of a old warrior chief"
+```
+
+## 速度
+
+
+
+💡 如果你没有 GPU, 你可以从像 [Colab](https://colab.research.google.com/) 这样的 GPU 提供商获取免费的 GPU !
+
+
+
+加速推理的最简单方法之一是将 pipeline 放在 GPU 上 ,就像使用任何 PyTorch 模块一样:
+
+```python
+pipeline = pipeline.to("cuda")
+```
+
+为了确保您可以使用相同的图像并对其进行改进,使用 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) 方法,然后设置一个随机数种子 以确保其 [复现性](./using-diffusers/reproducibility):
+
+```python
+import torch
+
+generator = torch.Generator("cuda").manual_seed(0)
+```
+
+现在,你可以生成一个图像:
+
+```python
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+

+
+
+在 T4 GPU 上,这个过程大概要30秒(如果你的 GPU 比 T4 好,可能会更快)。在默认情况下,[`DiffusionPipeline`] 使用完整的 `float32` 精度进行 50 步推理。你可以通过降低精度(如 `float16` )或者减少推理步数来加速整个过程
+
+
+让我们把模型的精度降低至 `float16` ,然后生成一张图像:
+
+```python
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
+pipeline = pipeline.to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+

+
+
+这一次,生成图像只花了约 11 秒,比之前快了近 3 倍!
+
+
+
+💡 我们强烈建议把 pipeline 精度降低至 `float16` , 到目前为止, 我们很少看到输出质量有任何下降。
+
+
+
+另一个选择是减少推理步数。你可以选择一个更高效的调度器(*scheduler*),在减少推理步数的同时保证输出质量。你可以通过 `pipeline.scheduler.compatibles` 查看与当前模型兼容的调度器(*scheduler*)。
+
+```python
+pipeline.scheduler.compatibles
+[
+ diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
+ diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
+ diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
+ diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
+ diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
+ diffusers.schedulers.scheduling_pndm.PNDMScheduler,
+ diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_ddim.DDIMScheduler,
+]
+```
+
+Stable Diffusion 模型默认使用的是 [`PNDMScheduler`] ,通常要大概50步推理, 但是像 [`DPMSolverMultistepScheduler`] 这样更高效的调度器只要大概 20 或 25 步推理. 使用 [`ConfigMixin.from_config`] 方法加载新的调度器:
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+```
+
+现在将 `num_inference_steps` 设置为 20:
+
+```python
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+

+
+
+太棒了!你成功把推理时间缩短到 4 秒!⚡️
+
+## 内存
+
+改善 pipeline 性能的另一个关键是减少内存的使用量,这间接意味着速度更快,因为你经常试图最大化每秒生成的图像数量。要想知道你一次可以生成多少张图片,最简单的方法是尝试不同的batch size,直到出现`OutOfMemoryError` (OOM)。
+
+创建一个函数,为每一批要生成的图像分配提示词和 `Generators` 。请务必为每个`Generator` 分配一个种子,以便于复现良好的结果。
+
+
+```python
+def get_inputs(batch_size=1):
+ generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
+ prompts = batch_size * [prompt]
+ num_inference_steps = 20
+
+ return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
+```
+
+设置 `batch_size=4` ,然后看一看我们消耗了多少内存:
+
+```python
+from diffusers.utils import make_image_grid
+
+images = pipeline(**get_inputs(batch_size=4)).images
+make_image_grid(images, 2, 2)
+```
+
+除非你有一个更大内存的GPU, 否则上述代码会返回 `OOM` 错误! 大部分内存被 cross-attention 层使用。按顺序运行可以节省大量内存,而不是在批处理中进行。你可以为 pipeline 配置 [`~DiffusionPipeline.enable_attention_slicing`] 函数:
+
+```python
+pipeline.enable_attention_slicing()
+```
+
+现在尝试把 `batch_size` 增加到 8!
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+

+
+
+以前你连一批 4 张图片都生成不了,而现在你可以在大约 3.5 秒内一批生成 8 张图片!这可能是在不牺牲质量的情况下,T4 GPU 运行速度最快的一种方法。
+
+## 质量
+
+在最后两节中, 你要学习如何通过 `fp16` 来优化 pipeline 的速度, 通过使用性能更高的调度器来减少推理步数, 使用注意力切片(*enabling attention slicing*)方法来节省内存。现在,你将关注的是如何提高图像的质量。
+
+### 更好的 checkpoints
+
+有个显而易见的方法是使用更好的 checkpoints。 Stable Diffusion 模型是一个很好的起点, 自正式发布以来,还发布了几个改进版本。然而, 使用更新的版本并不意味着你会得到更好的结果。你仍然需要尝试不同的 checkpoints ,并做一些研究 (例如使用 [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) 来获得更好的结果。
+
+随着该领域的发展, 有越来越多经过微调的高质量的 checkpoints 用来生成不一样的风格. 在 [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) 和 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) 寻找你感兴趣的一种!
+
+### 更好的 pipeline 组件
+
+也可以尝试用新版本替换当前的 pipeline 组件。让我们把 Stability AI 最新的 [autoencoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) 加载到 pipeline 中,并生成一些图像:
+
+```python
+from diffusers import AutoencoderKL
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
+pipeline.vae = vae
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+

+
+
+### 更好的提示词工程
+
+用于生成图像的文本非常重要, 因此被称为 *提示词工程*。 在设计提示词工程应注意如下事项:
+
+- 我想生成的图像或类似图像如何存储在互联网上?
+- 我可以提供哪些额外的细节来引导模型朝着我想要的风格生成?
+
+考虑到这一点,让我们改进提示词,以包含颜色和更高质量的细节:
+
+```python
+prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
+prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
+```
+
+使用新的提示词生成一批图像:
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+

+
+
+非常令人印象深刻!让我们微调一下第二张图像:把 `Generator` 的种子设置为 `1`,并添加一些描述人物年龄的文本:
+
+```python
+prompts = [
+ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+]
+
+generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
+images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
+make_image_grid(images, 2, 2)
+```
+
+
+

+
+
+## 最后
+
+在本教程中, 您学习了如何优化[`DiffusionPipeline`]以提高计算和内存效率,以及提高生成输出的质量. 如果你有兴趣让你的 pipeline 更快, 可以看一看以下资源:
+
+- 学习 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html),可以让推理速度提高 5% - 300%。在 A100 GPU 上,推理速度甚至可以提高 50%!(`torch.compile` 的简要示例见本节末尾)
+- 如果你没法用 PyTorch 2, 我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制(*memory-efficient attention mechanism*)与PyTorch 1.13.1配合使用,速度更快,内存消耗更少。
+- 其他的优化技术, 如:模型卸载(*model offloading*), 包含在 [这份指南](./optimization/fp16).
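+
+关于第一条提到的 `torch.compile`,这里给出一个简要示例(仅作参考,假设 `pipeline` 已按本教程加载并移动到 GPU,且安装了 PyTorch 2.0 及以上版本):
+
+```python
+import torch
+
+# 编译 UNet;首次调用会触发编译,之后的推理调用会明显加快
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
+```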
diff --git a/examples/README.md b/examples/README.md
index f0d8a6bb57f0..9566e68fc51d 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -19,7 +19,7 @@ Diffusers examples are a collection of scripts to demonstrate how to effectively
for a variety of use cases involving training or fine-tuning.
**Note**: If you are looking for **official** examples on how to use `diffusers` for inference,
-please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
+please have a look at [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
Our examples aspire to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**.
More specifically, this means:
diff --git a/examples/community/README.md b/examples/community/README.md
index aee6ffee09c7..51ce59edec6c 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -8,7 +8,6 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
-| LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) |
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) |
@@ -42,14 +41,10 @@ If a community doesn't work as expected, please open an issue and ping the autho
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#Zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | - | [Andrew Zhu](https://xhinker.medium.com/) |
-FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
+FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
-| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
-| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
-| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
-| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
-|
+
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -58,82 +53,6 @@ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custo
## Example usages
-### LLM-grounded Diffusion
-
-LMD and LMD+ greatly improves the prompt understanding ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. It improves spatial reasoning, the understanding of negation, attribute binding, generative numeracy, etc. in a unified manner without explicitly aiming for each. LMD is completely training-free (i.e., uses SD model off-the-shelf). LMD+ takes in additional adapters for better control. This is a reproduction of LMD+ model used in our work. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion)
-
-
-
-
-This pipeline can be used with an LLM or on its own. We provide a parser that parses LLM outputs to the layouts. You can obtain the prompt to input to the LLM for layout generation [here](https://github.com/TonyLianLong/LLM-groundedDiffusion/blob/main/prompt.py). After feeding the prompt to an LLM (e.g., GPT-4 on ChatGPT website), you can feed the LLM response into our pipeline.
-
-The following code has been tested on 1x RTX 4090, but it should also support GPUs with lower GPU memory.
-
-#### Use this pipeline with an LLM
-```python
-import torch
-from diffusers import DiffusionPipeline
-
-pipe = DiffusionPipeline.from_pretrained(
- "longlian/lmd_plus",
- custom_pipeline="llm_grounded_diffusion",
- variant="fp16", torch_dtype=torch.float16
-)
-pipe.enable_model_cpu_offload()
-
-# Generate directly from a text prompt and an LLM response
-prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
-phrases, boxes, bg_prompt, neg_prompt = pipe.parse_llm_response("""
-[('a waterfall', [71, 105, 148, 258]), ('a modern high speed train', [255, 223, 181, 149])]
-Background prompt: A beautiful forest with fall foliage
-Negative prompt:
-""")
-
-images = pipe(
- prompt=prompt,
- negative_prompt=neg_prompt,
- phrases=phrases,
- boxes=boxes,
- gligen_scheduled_sampling_beta=0.4,
- output_type="pil",
- num_inference_steps=50,
- lmd_guidance_kwargs={}
-).images
-
-images[0].save("./lmd_plus_generation.jpg")
-```
-
-#### Use this pipeline on its own for layout generation
-```python
-import torch
-from diffusers import DiffusionPipeline
-
-pipe = DiffusionPipeline.from_pretrained(
- "longlian/lmd_plus",
- custom_pipeline="llm_grounded_diffusion",
- variant="fp16", torch_dtype=torch.float16
-)
-pipe.enable_model_cpu_offload()
-
-# Generate an image described by the prompt and
-# insert objects described by text at the region defined by bounding boxes
-prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
-boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]
-phrases = ["a waterfall", "a modern high speed train"]
-
-images = pipe(
- prompt=prompt,
- phrases=phrases,
- boxes=boxes,
- gligen_scheduled_sampling_beta=0.4,
- output_type="pil",
- num_inference_steps=50,
- lmd_guidance_kwargs={}
-).images
-
-images[0].save("./lmd_plus_generation.jpg")
-```
-
### CLIP Guided Stable Diffusion
CLIP guided stable diffusion can help to generate more realistic images
@@ -846,7 +765,7 @@ pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom
#There are multiple possible scenarios:
#The pipeline with the merged checkpoints is returned in all the scenarios
-#Compatible checkpoints a.k.a matched model_index.json files. Ignores the meta attributes in model_index.json during comparison.( attrs with _ as prefix )
+#Compatible checkpoints a.k.a. matched model_index.json files. Ignores the meta attributes in model_index.json during comparison (attrs with _ as prefix)
merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4","CompVis/stable-diffusion-v1-2"], interp = "sigmoid", alpha = 0.4)
#Incompatible checkpoints in model_index.json but merge might be possible. Use force = True to ignore model_index.json compatibility
@@ -1610,14 +1529,14 @@ print("Latency of StableDiffusionPipeline--fp32",latency)

-CLIP guided stable diffusion images mixing pipeline allows to combine two images using standard diffusion models.
+CLIP guided stable diffusion images mixing pipeline allows you to combine two images using standard diffusion models.
This approach uses an (optional) CoCa model to avoid writing an image description.
[More code examples](https://github.com/TheDenk/images_mixing)
### Stable Diffusion XL Long Weighted Prompt Pipeline
-This SDXL pipeline support unlimited length prompt and negative prompt, compatible with A1111 prompt weighted style.
+This SDXL pipeline supports unlimited length prompts and negative prompts, compatible with the A1111 prompt weighting style.
You can provide both `prompt` and `prompt_2`. If only one prompt is provided, `prompt_2` will be a copy of the provided `prompt`. Here is sample code to use this pipeline.
@@ -1686,7 +1605,7 @@ coca_transform = open_clip.image_transform(
)
coca_tokenizer = SimpleTokenizer()
-# Pipeline creating
+# Pipeline creation
mixing_pipeline = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
custom_pipeline="clip_guided_images_mixing_stable_diffusion",
@@ -1700,7 +1619,7 @@ mixing_pipeline = DiffusionPipeline.from_pretrained(
mixing_pipeline.enable_attention_slicing()
mixing_pipeline = mixing_pipeline.to("cuda")
-# Pipeline running
+# Run the pipeline
generator = torch.Generator(device="cuda").manual_seed(17)
def download_image(url):
@@ -2053,7 +1972,7 @@ import torch
from PIL import Image
from io import BytesIO
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
@@ -2228,440 +2147,3 @@ edit_kcross_attention_kwargswargs = {
```
Side note: See [this GitHub gist](https://gist.github.com/UmerHA/b65bb5fb9626c9c73f3ade2869e36164) if you want to visualize the attention maps.
-
-### Latent Consistency Pipeline
-
-Latent Consistency Models was proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) by *Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, Hang Zhao* from Tsinghua University.
-
-The abstract of the paper reads as follows:
-
-*Latent Diffusion models (LDMs) have achieved remarkable results in synthesizing high-resolution images. However, the iterative sampling process is computationally intensive and leads to slow generation. Inspired by Consistency Models (song et al.), we propose Latent Consistency Models (LCMs), enabling swift inference with minimal steps on any pre-trained LDMs, including Stable Diffusion (rombach et al). Viewing the guided reverse diffusion process as solving an augmented probability flow ODE (PF-ODE), LCMs are designed to directly predict the solution of such ODE in latent space, mitigating the need for numerous iterations and allowing rapid, high-fidelity sampling. Efficiently distilled from pre-trained classifier-free guided diffusion models, a high-quality 768 x 768 2~4-step LCM takes only 32 A100 GPU hours for training. Furthermore, we introduce Latent Consistency Fine-tuning (LCF), a novel method that is tailored for fine-tuning LCMs on customized image datasets. Evaluation on the LAION-5B-Aesthetics dataset demonstrates that LCMs achieve state-of-the-art text-to-image generation performance with few-step inference. Project Page: [this https URL](https://latent-consistency-models.github.io/)*
-
-The model can be used with `diffusers` as follows:
-
- - *1. Load the model from the community pipeline.*
-
-```py
-from diffusers import DiffusionPipeline
-import torch
-
-pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_txt2img", custom_revision="main")
-
-# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
-pipe.to(torch_device="cuda", torch_dtype=torch.float32)
-```
-
-- 2. Run inference with as little as 4 steps:
-
-```py
-prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
-
-# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
-num_inference_steps = 4
-
-images = pipe(prompt=prompt, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images
-```
-
-For any questions or feedback, feel free to reach out to [Simian Luo](https://github.com/luosiallen).
-
-You can also try this pipeline directly in the [🚀 official spaces](https://huggingface.co/spaces/SimianLuo/Latent_Consistency_Model).
-
-
-
-### Latent Consistency Img2img Pipeline
-
-This pipeline extends the Latent Consistency Pipeline to allow it to take an input image.
-
-```py
-from diffusers import DiffusionPipeline
-import torch
-
-pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_img2img")
-
-# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
-pipe.to(torch_device="cuda", torch_dtype=torch.float32)
-```
-
-- 2. Run inference with as little as 4 steps:
-
-```py
-prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
-
-
-input_image=Image.open("myimg.png")
-
-strength = 0.5 #strength =0 (no change) strength=1 (completely overwrite image)
-
-# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
-num_inference_steps = 4
-
-images = pipe(prompt=prompt, image=input_image, strength=strength, num_inference_steps=num_inference_steps, guidance_scale=8.0, lcm_origin_steps=50, output_type="pil").images
-```
-
-
-
-### Latent Consistency Interpolation Pipeline
-
-This pipeline extends the Latent Consistency Pipeline to allow for interpolation of the latent space between multiple prompts. It is similar to the [Stable Diffusion Interpolate](https://github.com/huggingface/diffusers/blob/main/examples/community/interpolate_stable_diffusion.py) and [unCLIP Interpolate](https://github.com/huggingface/diffusers/blob/main/examples/community/unclip_text_interpolation.py) community pipelines.
-
-```py
-import torch
-import numpy as np
-
-from diffusers import DiffusionPipeline
-
-pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", custom_pipeline="latent_consistency_interpolate")
-
-# To save GPU memory, torch.float16 can be used, but it may compromise image quality.
-pipe.to(torch_device="cuda", torch_dtype=torch.float32)
-
-prompts = [
- "Self-portrait oil painting, a beautiful cyborg with golden hair, Margot Robbie, 8k",
- "Self-portrait oil painting, an extremely strong man, body builder, Huge Jackman, 8k",
- "An astronaut floating in space, renaissance art, realistic, high quality, 8k",
- "Oil painting of a cat, cute, dream-like",
- "Hugging face emoji, cute, realistic"
-]
-num_inference_steps = 4
-num_interpolation_steps = 60
-seed = 1337
-
-torch.manual_seed(seed)
-np.random.seed(seed)
-
-images = pipe(
- prompt=prompts,
- height=512,
- width=512,
- num_inference_steps=num_inference_steps,
- num_interpolation_steps=num_interpolation_steps,
- guidance_scale=8.0,
- embedding_interpolation_type="lerp",
- latent_interpolation_type="slerp",
- process_batch_size=4, # Make it higher or lower based on your GPU memory
- generator=torch.Generator(seed),
-)
-
-assert len(images) == (len(prompts) - 1) * num_interpolation_steps
-```
-
-### StableDiffusionUpscaleLDM3D Pipeline
-[LDM3D-VR](https://arxiv.org/pdf/2311.03226.pdf) is an extended version of LDM3D.
-
-The abstract from the paper is:
-*Latent diffusion models have proven to be state-of-the-art in the creation and manipulation of visual outputs. However, as far as we know, the generation of depth maps jointly with RGB is still limited. We introduce LDM3D-VR, a suite of diffusion models targeting virtual reality development that includes LDM3D-pano and LDM3D-SR. These models enable the generation of panoramic RGBD based on textual prompts and the upscaling of low-resolution inputs to high-resolution RGBD, respectively. Our models are fine-tuned from existing pretrained models on datasets containing panoramic/high-resolution RGB images, depth maps and captions. Both models are evaluated in comparison to existing related methods*
-
-Two checkpoints are available for use:
-- [ldm3d-pano](https://huggingface.co/Intel/ldm3d-pano). This checkpoint enables the generation of panoramic images and requires the StableDiffusionLDM3DPipeline pipeline to be used.
-- [ldm3d-sr](https://huggingface.co/Intel/ldm3d-sr). This checkpoint enables the upscaling of RGB and depth images. Can be used in cascade after the original LDM3D pipeline using the StableDiffusionUpscaleLDM3DPipeline pipeline.
-
-'''py
-from PIL import Image
-import os
-import torch
-from diffusers import StableDiffusionLDM3DPipeline, DiffusionPipeline
-
-#Generate a rgb/depth output from LDM3D
-pipe_ldm3d = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c")
-pipe_ldm3d.to("cuda")
-
-prompt =f"A picture of some lemons on a table"
-output = pipe_ldm3d(prompt)
-rgb_image, depth_image = output.rgb, output.depth
-rgb_image[0].save(f"lemons_ldm3d_rgb.jpg")
-depth_image[0].save(f"lemons_ldm3d_depth.png")
-
-
-#Upscale the previous output to a resolution of (1024, 1024)
-pipe_ldm3d_upscale = DiffusionPipeline.from_pretrained("Intel/ldm3d-sr", custom_pipeline="pipeline_stable_diffusion_upscale_ldm3d")
-
-pipe_ldm3d_upscale.to("cuda")
-
-low_res_img = Image.open(f"lemons_ldm3d_rgb.jpg").convert("RGB")
-low_res_depth = Image.open(f"lemons_ldm3d_depth.png").convert("L")
-outputs = pipe_ldm3d_upscale(prompt="high quality high resolution uhd 4k image", rgb=low_res_img, depth=low_res_depth, num_inference_steps=50, target_res=[1024, 1024])
-
-upscaled_rgb, upscaled_depth =outputs.rgb[0], outputs.depth[0]
-upscaled_rgb.save(f"upscaled_lemons_rgb.png")
-upscaled_depth.save(f"upscaled_lemons_depth.png")
-'''
-
-### ControlNet + T2I Adapter Pipeline
-This pipelines combines both ControlNet and T2IAdapter into a single pipeline, where the forward pass is executed once.
-It receives `control_image` and `adapter_image`, as well as `controlnet_conditioning_scale` and `adapter_conditioning_scale`, for the ControlNet and Adapter modules, respectively. Whenever `adapter_conditioning_scale = 0` or `controlnet_conditioning_scale = 0`, it will act as a full ControlNet module or as a full T2IAdapter module, respectively.
-
-```py
-import cv2
-import numpy as np
-import torch
-from controlnet_aux.midas import MidasDetector
-from PIL import Image
-
-from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
-from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
-from diffusers.utils import load_image
-from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter import (
- StableDiffusionXLControlNetAdapterPipeline,
-)
-
-controlnet_depth = ControlNetModel.from_pretrained(
- "diffusers/controlnet-depth-sdxl-1.0",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True
-)
-adapter_depth = T2IAdapter.from_pretrained(
- "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
-
-pipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- controlnet=controlnet_depth,
- adapter=adapter_depth,
- vae=vae,
- variant="fp16",
- use_safetensors=True,
- torch_dtype=torch.float16,
-)
-pipe = pipe.to("cuda")
-pipe.enable_xformers_memory_efficient_attention()
-# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
-midas_depth = MidasDetector.from_pretrained(
- "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
-).to("cuda")
-
-prompt = "a tiger sitting on a park bench"
-img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-
-image = load_image(img_url).resize((1024, 1024))
-
-depth_image = midas_depth(
- image, detect_resolution=512, image_resolution=1024
-)
-
-strength = 0.5
-
-images = pipe(
- prompt,
- control_image=depth_image,
- adapter_image=depth_image,
- num_inference_steps=30,
- controlnet_conditioning_scale=strength,
- adapter_conditioning_scale=strength,
-).images
-images[0].save("controlnet_and_adapter.png")
-
-```
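-
-Per the description above, setting one of the two scales to zero disables that branch. For instance, this hypothetical variation of the call above behaves like a plain ControlNet pipeline (symmetrically, `controlnet_conditioning_scale=0.0` would leave only the T2IAdapter active):
-
-```py
-images = pipe(
-    prompt,
-    control_image=depth_image,
-    adapter_image=depth_image,
-    num_inference_steps=30,
-    controlnet_conditioning_scale=strength,
-    adapter_conditioning_scale=0.0,  # the adapter branch contributes nothing
-).images
-```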
-
-### ControlNet + T2I Adapter + Inpainting Pipeline
-```py
-import cv2
-import numpy as np
-import torch
-from controlnet_aux.midas import MidasDetector
-from PIL import Image
-
-from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
-from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
-from diffusers.utils import load_image
-from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter_inpaint import (
- StableDiffusionXLControlNetAdapterInpaintPipeline,
-)
-
-controlnet_depth = ControlNetModel.from_pretrained(
- "diffusers/controlnet-depth-sdxl-1.0",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True
-)
-adapter_depth = T2IAdapter.from_pretrained(
- "TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
-)
-vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
-
-pipe = StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained(
- "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
- controlnet=controlnet_depth,
- adapter=adapter_depth,
- vae=vae,
- variant="fp16",
- use_safetensors=True,
- torch_dtype=torch.float16,
-)
-pipe = pipe.to("cuda")
-pipe.enable_xformers_memory_efficient_attention()
-# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
-midas_depth = MidasDetector.from_pretrained(
- "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
-).to("cuda")
-
-prompt = "a tiger sitting on a park bench"
-img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
-
-image = load_image(img_url).resize((1024, 1024))
-mask_image = load_image(mask_url).resize((1024, 1024))
-
-depth_image = midas_depth(
- image, detect_resolution=512, image_resolution=1024
-)
-
-strength = 0.4
-
-images = pipe(
- prompt,
- image=image,
- mask_image=mask_image,
- control_image=depth_image,
- adapter_image=depth_image,
- num_inference_steps=30,
- controlnet_conditioning_scale=strength,
- adapter_conditioning_scale=strength,
- strength=0.7,
-).images
-images[0].save("controlnet_and_adapter_inpaint.png")
-
-```
-
-## Diffusion Posterior Sampling Pipeline
-* Reference paper
- ```
- @article{chung2022diffusion,
- title={Diffusion posterior sampling for general noisy inverse problems},
- author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul},
- journal={arXiv preprint arXiv:2209.14687},
- year={2022}
- }
- ```
-* This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given an observation $y$, an unconditional generative model $p(x)$, and a differentiable operator $y=f(x)$.
-* For example, if $f(\cdot)$ is a downsampling operator, then $y$ is a downsampled image and the pipeline performs super-resolution.
-* To use this pipeline, you need to know your operator $f(\cdot)$ and the corrupted image $y$, and pass them during the call. For example, as in the main function of dps_pipeline.py, first define the Gaussian blurring operator $f(\cdot)$. The operator should be a callable `nn.Module` with all parameter gradients disabled:
- ```python
-  import numpy as np
-  import scipy.ndimage
-  import torch
-  import torch.nn.functional as F
-  from torch import nn
-
- # define the Gaussian blurring operator first
- class GaussialBlurOperator(nn.Module):
- def __init__(self, kernel_size, intensity):
- super().__init__()
-
- class Blurkernel(nn.Module):
- def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0):
- super().__init__()
- self.blur_type = blur_type
- self.kernel_size = kernel_size
- self.std = std
- self.seq = nn.Sequential(
- nn.ReflectionPad2d(self.kernel_size//2),
- nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
- )
- self.weights_init()
-
- def forward(self, x):
- return self.seq(x)
-
- def weights_init(self):
- if self.blur_type == "gaussian":
- n = np.zeros((self.kernel_size, self.kernel_size))
- n[self.kernel_size // 2,self.kernel_size // 2] = 1
- k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
- k = torch.from_numpy(k)
- self.k = k
- for name, f in self.named_parameters():
- f.data.copy_(k)
- elif self.blur_type == "motion":
- k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
- k = torch.from_numpy(k)
- self.k = k
- for name, f in self.named_parameters():
- f.data.copy_(k)
-
- def update_weights(self, k):
- if not torch.is_tensor(k):
- k = torch.from_numpy(k)
- for name, f in self.named_parameters():
- f.data.copy_(k)
-
- def get_kernel(self):
- return self.k
-
- self.kernel_size = kernel_size
- self.conv = Blurkernel(blur_type='gaussian',
- kernel_size=kernel_size,
- std=intensity)
- self.kernel = self.conv.get_kernel()
- self.conv.update_weights(self.kernel.type(torch.float32))
-
- for param in self.parameters():
- param.requires_grad=False
-
- def forward(self, data, **kwargs):
- return self.conv(data)
-
- def transpose(self, data, **kwargs):
- return data
-
- def get_kernel(self):
- return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
- ```
-* Next, you should obtain the corrupted image $y$ by applying the operator. In this example, we generate $y$ from the source image $x$; in practice, having the operator $f(\cdot)$ and the corrupted image $y$ is enough:
- ```python
-  from PIL import Image
-  from torchvision.utils import save_image
-
-  # set up source image
- src = Image.open('sample.png')
- # read image into [1,3,H,W]
- src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2,0,1)[None]
- # normalize image to [-1,1]
- src = (src / 127.5) - 1.0
- src = src.to("cuda")
-
- # set up operator and measurement
- operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
- measurement = operator(src)
-
- # save the source and corrupted images
- save_image((src+1.0)/2.0, "dps_src.png")
- save_image((measurement+1.0)/2.0, "dps_mea.png")
- ```
-* We provide an example pair of saved source and corrupted images, generated with the Gaussian blur operator above:
- * Source image:
- * 
- * Gaussian blurred image:
- * 
-  * You can download these images to run the example on your own.
-* Next, we need to define a loss function for diffusion posterior sampling. For most cases, an RMSE-style loss works well:
- ```python
- def RMSELoss(yhat, y):
- return torch.sqrt(torch.sum((yhat-y)**2))
- ```
-* Next, as with any other diffusion model, we need a score estimator and a scheduler. As we are working with $256 \times 256$ face images, we use `google/ddpm-celebahq-256`:
- ```python
-  from diffusers import DDPMScheduler, UNet2DModel
-
-  # set up scheduler
- scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
- scheduler.set_timesteps(1000)
-
- # set up model
- model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
- ```
-* Finally, run the pipeline:
- ```python
-  # finally, build the pipeline (DPSPipeline is defined in examples/community/dps_pipeline.py)
-  dpspipe = DPSPipeline(model, scheduler)
-  image = dpspipe(
-      measurement=measurement,
-      operator=operator,
-      loss_fn=RMSELoss,
-      zeta=1.0,
-  ).images[0]
-  image.save("dps_generated_image.png")
- ```
-* `zeta` is a hyperparameter in the range $[0, 1]$. It needs to be tuned for the best effect. With `zeta=1`, you should be able to obtain the reconstructed result:
- * Reconstructed image:
- * 
-* The reconstruction is perceptually similar to the source image but differs in the details.
-* In dps_pipeline.py, we also provide a super-resolution example (a minimal sketch of a matching downsampling operator is shown right after this list), which should produce:
- * Downsampled image:
- * 
- * Reconstructed image:
- * 
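-
-As a complement to the Gaussian blur example above, the sketch below shows what a downsampling operator with the same interface (`forward`/`transpose` accepting `**kwargs`) could look like, turning the pipeline into a super-resolution solver. The scale factor and interpolation mode here are assumptions for illustration, not the exact settings used in dps_pipeline.py:
-
-```python
-import torch.nn.functional as F
-from torch import nn
-
-
-class DownsampleOperator(nn.Module):
-    def __init__(self, scale_factor=4):
-        super().__init__()
-        self.scale_factor = scale_factor
-
-    def forward(self, data, **kwargs):
-        # f(x): bicubic downsampling of a [B, C, H, W] tensor in [-1, 1]
-        return F.interpolate(data, scale_factor=1 / self.scale_factor, mode="bicubic")
-
-    def transpose(self, data, **kwargs):
-        # kept only for interface compatibility with the Gaussian blur operator above
-        return data
-
-
-# measurement = DownsampleOperator(scale_factor=4).to("cuda")(src)
-```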
diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py
index 10381020bf63..02e8684e6ade 100644
--- a/examples/community/checkpoint_merger.py
+++ b/examples/community/checkpoint_merger.py
@@ -13,7 +13,7 @@
class CheckpointMergerPipeline(DiffusionPipeline):
"""
- A class that supports merging diffusion models based on the discussion here:
+    A class that supports merging diffusion models based on the discussion here:
https://github.com/huggingface/diffusers/issues/877
Example usage:-
diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py
index 444d3375c3d1..996bb3cef8bf 100644
--- a/examples/community/composable_stable_diffusion.py
+++ b/examples/community/composable_stable_diffusion.py
@@ -65,7 +65,6 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
diff --git a/examples/community/ddim_noise_comparative_analysis.py b/examples/community/ddim_noise_comparative_analysis.py
index 482c0a5826d2..e1633ce4636b 100644
--- a/examples/community/ddim_noise_comparative_analysis.py
+++ b/examples/community/ddim_noise_comparative_analysis.py
@@ -18,7 +18,7 @@
import torch
from torchvision import transforms
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor
diff --git a/examples/community/iadb.py b/examples/community/iadb.py
index 6089e49fc621..1f421ee0ea4c 100644
--- a/examples/community/iadb.py
+++ b/examples/community/iadb.py
@@ -4,7 +4,7 @@
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
-from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
+from diffusers.pipeline_utils import ImagePipelineOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index 7249e033186f..ee0cdc461cf5 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -56,10 +56,10 @@ def parse_prompt_attention(text):
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
- \\( - literal character '('
- \\[ - literal character '['
- \\) - literal character ')'
- \\] - literal character ']'
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
@@ -68,7 +68,7 @@ def parse_prompt_attention(text):
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
- >>> parse_prompt_attention('\\(literal\\]')
+ >>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py
index 87c2944dbc44..423e6ced4d77 100644
--- a/examples/community/lpw_stable_diffusion_onnx.py
+++ b/examples/community/lpw_stable_diffusion_onnx.py
@@ -82,10 +82,10 @@ def parse_prompt_attention(text):
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
- \\( - literal character '('
- \\[ - literal character '['
- \\) - literal character ')'
- \\] - literal character ']'
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
@@ -94,7 +94,7 @@ def parse_prompt_attention(text):
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
- >>> parse_prompt_attention('\\(literal\\]')
+ >>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
@@ -433,7 +433,6 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
"""
-
if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
def __init__(
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index dfe60d9794e1..66e2ffb159a1 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -46,10 +46,10 @@ def parse_prompt_attention(text):
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
- \\( - literal character '('
- \\[ - literal character '['
- \\) - literal character ')'
- \\] - literal character ']'
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
\\ - literal character '\'
anything else - just text
@@ -59,7 +59,7 @@ def parse_prompt_attention(text):
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
- >>> parse_prompt_attention('\\(literal\\]')
+ >>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
@@ -249,8 +249,6 @@ def get_weighted_text_embeddings_sdxl(
prompt_2: str = None,
neg_prompt: str = "",
neg_prompt_2: str = None,
- num_images_per_prompt: int = 1,
- device: Optional[torch.device] = None,
):
"""
This function can process long prompt with weights, no length limitation
@@ -262,14 +260,10 @@ def get_weighted_text_embeddings_sdxl(
prompt_2 (str)
neg_prompt (str)
neg_prompt_2 (str)
- num_images_per_prompt (int)
- device (torch.device)
Returns:
prompt_embeds (torch.Tensor)
neg_prompt_embeds (torch.Tensor)
"""
- device = device or pipe._execution_device
-
if prompt_2:
prompt = f"{prompt} {prompt_2}"
@@ -334,17 +328,17 @@ def get_weighted_text_embeddings_sdxl(
# get prompt embeddings one by one is not working.
for i in range(len(prompt_token_groups)):
# get positive prompt embeddings with weights
- token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
- weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
+ token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
+ weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
- token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
+ token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
# use first text encoder
- prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
+ prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
# use second text encoder
- prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
+ prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
pooled_prompt_embeds = prompt_embeds_2[0]
@@ -361,16 +355,16 @@ def get_weighted_text_embeddings_sdxl(
embeds.append(token_embedding)
# get negative prompt embeddings with weights
- neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
- neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
- neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
+ neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
+ neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
+ neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
# use first text encoder
- neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
+ neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
# use second text encoder
- neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
+ neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
@@ -389,22 +383,6 @@ def get_weighted_text_embeddings_sdxl(
prompt_embeds = torch.cat(embeds, dim=1)
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
- bs_embed, seq_len, _ = prompt_embeds.shape
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
- seq_len = negative_prompt_embeds.shape[1]
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
- bs_embed * num_images_per_prompt, -1
- )
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
- bs_embed * num_images_per_prompt, -1
- )
-
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@@ -1118,9 +1096,7 @@ def __call__(
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
- ) = get_weighted_text_embeddings_sdxl(
- pipe=self, prompt=prompt, neg_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt
- )
+ ) = get_weighted_text_embeddings_sdxl(pipe=self, prompt=prompt, neg_prompt=negative_prompt)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
diff --git a/examples/community/magic_mix.py b/examples/community/magic_mix.py
index d3d118f84bfc..4eb99cb96b42 100644
--- a/examples/community/magic_mix.py
+++ b/examples/community/magic_mix.py
@@ -127,9 +127,9 @@ def __call__(
timesteps=t,
)
- input = (
- (mix_factor * latents) + (1 - mix_factor) * orig_latents
- ) # interpolating between layout noise and conditionally generated noise to preserve layout sematics
+ input = (mix_factor * latents) + (
+ 1 - mix_factor
+        ) * orig_latents  # interpolating between layout noise and conditionally generated noise to preserve layout semantics
input = torch.cat([input] * 2)
else: # content generation phase
diff --git a/examples/community/mixture_canvas.py b/examples/community/mixture_canvas.py
index 3737183e5513..40139d1139ad 100644
--- a/examples/community/mixture_canvas.py
+++ b/examples/community/mixture_canvas.py
@@ -12,7 +12,7 @@
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -453,7 +453,9 @@ def __call__(
:,
region.latent_row_init : region.latent_row_end,
region.latent_col_init : region.latent_col_end,
- ] += noise_pred_region * mask_weights_region
+ ] += (
+ noise_pred_region * mask_weights_region
+ )
contributors[
:,
:,
diff --git a/examples/community/mixture_tiling.py b/examples/community/mixture_tiling.py
index f92ae0e1d359..3e701cf607f5 100644
--- a/examples/community/mixture_tiling.py
+++ b/examples/community/mixture_tiling.py
@@ -7,7 +7,7 @@
from tqdm.auto import tqdm
from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging
diff --git a/examples/community/pipeline_fabric.py b/examples/community/pipeline_fabric.py
index 080d0c221727..c5783402b36c 100644
--- a/examples/community/pipeline_fabric.py
+++ b/examples/community/pipeline_fabric.py
@@ -14,6 +14,7 @@
from typing import List, Optional, Union
import torch
+from diffusers.utils.torch_utils import randn_tensor
from packaging import version
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
@@ -32,7 +33,6 @@
logging,
replace_example_docstring,
)
-from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 59b8e691bde3..7d330c668da9 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -65,7 +65,6 @@ class Prompt2PromptPipeline(StableDiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
_optional_components = ["safety_checker", "feature_extractor"]
@torch.no_grad()
diff --git a/examples/community/pipeline_zero1to3.py b/examples/community/pipeline_zero1to3.py
index 600cf2dc1b63..3e4e88ea5aa1 100644
--- a/examples/community/pipeline_zero1to3.py
+++ b/examples/community/pipeline_zero1to3.py
@@ -94,7 +94,6 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline):
cc_projection ([`CCProjection`]):
Projection layer to project the concated CLIP features and pose embeddings to the original CLIP feature size.
"""
-
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
@@ -659,8 +658,7 @@ def prepare_img_latents(self, image, batch_size, dtype, device, generator=None,
if isinstance(generator, list):
init_latents = [
- self.vae.encode(image[i : i + 1]).latent_dist.mode(generator[i])
- for i in range(batch_size) # sample
+ self.vae.encode(image[i : i + 1]).latent_dist.mode(generator[i]) for i in range(batch_size) # sample
]
init_latents = torch.cat(init_latents, dim=0)
else:
diff --git a/examples/community/run_onnx_controlnet.py b/examples/community/run_onnx_controlnet.py
index ed9b23318414..2b1123a4955c 100644
--- a/examples/community/run_onnx_controlnet.py
+++ b/examples/community/run_onnx_controlnet.py
@@ -553,7 +553,7 @@ def __call__(
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
- The initial image will be used as the starting point for the image generation process. Can also accept
+                The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again.
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
@@ -651,10 +651,9 @@ def __call__(
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = num_controlnet
- control_guidance_start, control_guidance_end = (
- mult * [control_guidance_start],
- mult * [control_guidance_end],
- )
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
diff --git a/examples/community/run_tensorrt_controlnet.py b/examples/community/run_tensorrt_controlnet.py
index aece5484e304..724f393eb122 100644
--- a/examples/community/run_tensorrt_controlnet.py
+++ b/examples/community/run_tensorrt_controlnet.py
@@ -657,7 +657,7 @@ def __call__(
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
- The initial image will be used as the starting point for the image generation process. Can also accept
+                The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again.
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
@@ -755,10 +755,9 @@ def __call__(
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = num_controlnet
- control_guidance_start, control_guidance_end = (
- mult * [control_guidance_start],
- mult * [control_guidance_end],
- )
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py
index 9371ac8819ed..b7fbc46b67cb 100755
--- a/examples/community/sd_text2img_k_diffusion.py
+++ b/examples/community/sd_text2img_k_diffusion.py
@@ -68,7 +68,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py
index a2b92fff0fb5..550aa8ba61a3 100644
--- a/examples/community/stable_diffusion_controlnet_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_img2img.py
@@ -9,8 +9,8 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
-from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py
index b87973366418..30903bbf66bf 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint.py
@@ -10,8 +10,8 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
-from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py
index 358fc1c6dc67..d786036bd58a 100644
--- a/examples/community/stable_diffusion_controlnet_reference.py
+++ b/examples/community/stable_diffusion_controlnet_reference.py
@@ -546,7 +546,7 @@ def hack_CrossAttnDownBlock2D_forward(
return hidden_states, output_states
- def hacked_DownBlock2D_forward(self, hidden_states, temb=None, *args, **kwargs):
+ def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
eps = 1e-6
output_states = ()
@@ -642,9 +642,7 @@ def hacked_CrossAttnUpBlock2D_forward(
return hidden_states
- def hacked_UpBlock2D_forward(
- self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs
- ):
+ def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
eps = 1e-6
for i, resnet in enumerate(self.resnets):
# pop res hidden states
diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py
index 6d86248acbe6..2f8131d6cbc0 100644
--- a/examples/community/stable_diffusion_ipex.py
+++ b/examples/community/stable_diffusion_ipex.py
@@ -21,9 +21,8 @@
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.configuration_utils import FrozenDict
-from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -62,7 +61,7 @@
"""
-class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionIPEXPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion on IPEX.
@@ -89,7 +88,6 @@ class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
@@ -252,7 +250,9 @@ def prepare_for_ipex(self, promt, dtype=torch.float32, height=None, width=None,
# optimize with ipex
if dtype == torch.bfloat16:
- self.unet = ipex.optimize(self.unet.eval(), dtype=torch.bfloat16, inplace=True)
+ self.unet = ipex.optimize(
+ self.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=unet_input_example
+ )
self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True)
self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True)
if self.safety_checker is not None:
@@ -262,6 +262,8 @@ def prepare_for_ipex(self, promt, dtype=torch.float32, height=None, width=None,
self.unet.eval(),
dtype=torch.float32,
inplace=True,
+ sample_input=unet_input_example,
+ level="O1",
weights_prepack=True,
auto_kernel_selection=False,
)
@@ -269,6 +271,7 @@ def prepare_for_ipex(self, promt, dtype=torch.float32, height=None, width=None,
self.vae.decoder.eval(),
dtype=torch.float32,
inplace=True,
+ level="O1",
weights_prepack=True,
auto_kernel_selection=False,
)
@@ -276,6 +279,7 @@ def prepare_for_ipex(self, promt, dtype=torch.float32, height=None, width=None,
self.text_encoder.eval(),
dtype=torch.float32,
inplace=True,
+ level="O1",
weights_prepack=True,
auto_kernel_selection=False,
)
@@ -284,6 +288,7 @@ def prepare_for_ipex(self, promt, dtype=torch.float32, height=None, width=None,
self.safety_checker.eval(),
dtype=torch.float32,
inplace=True,
+ level="O1",
weights_prepack=True,
auto_kernel_selection=False,
)
@@ -449,10 +454,6 @@ def _encode_prompt(
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
- # textual inversion: procecss multi-vector tokens if necessary
- if isinstance(self, TextualInversionLoaderMixin):
- prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
-
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -513,10 +514,6 @@ def _encode_prompt(
else:
uncond_tokens = negative_prompt
- # textual inversion: procecss multi-vector tokens if necessary
- if isinstance(self, TextualInversionLoaderMixin):
- uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
-
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py
index faed00b49d40..0fec5557a637 100644
--- a/examples/community/stable_diffusion_mega.py
+++ b/examples/community/stable_diffusion_mega.py
@@ -50,7 +50,6 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
diff --git a/examples/community/stable_diffusion_repaint.py b/examples/community/stable_diffusion_repaint.py
index 4da46b370815..ce4f245b31fa 100644
--- a/examples/community/stable_diffusion_repaint.py
+++ b/examples/community/stable_diffusion_repaint.py
@@ -170,7 +170,6 @@ class StableDiffusionRepaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index 76975d79c1b3..d60fa19e8a7f 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -56,7 +56,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
@@ -86,7 +86,6 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
controlnet=controlnet,
safety_checker=None,
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -250,13 +249,10 @@ def parse_args(input_args=None):
type=str,
default=None,
required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
)
parser.add_argument(
"--tokenizer_name",
@@ -771,13 +767,11 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
- )
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
if args.controlnet_model_name_or_path:
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index ba5f84fe2d2c..68162d7824ab 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -59,7 +59,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index d55ad9b2a834..04290885cf4b 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -58,7 +58,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
@@ -74,7 +74,6 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
unet=unet,
controlnet=controlnet,
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -244,18 +243,15 @@ def parse_args(input_args=None):
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
" If not specified controlnet weights are initialized from unet.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
)
parser.add_argument(
"--tokenizer_name",
@@ -797,16 +793,10 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer_2",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
)
# import correct text encoder classes
@@ -820,10 +810,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
vae_path = (
args.pretrained_model_name_or_path
@@ -834,10 +824,9 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
- variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
if args.controlnet_model_name_or_path:
diff --git a/examples/custom_diffusion/README.md b/examples/custom_diffusion/README.md
index e686933feb51..9e3c387e3d34 100644
--- a/examples/custom_diffusion/README.md
+++ b/examples/custom_diffusion/README.md
@@ -48,7 +48,7 @@ write_basic_config()
Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
-We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
+We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training.
```bash
@@ -82,7 +82,7 @@ accelerate launch train_custom_diffusion.py \
**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
-To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (which we HIGHLY recommend), follow these steps:
+To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (which we HIGHLY recommend), follow these steps:
* Install `wandb`: `pip install wandb`.
* Authorize: `wandb login`.
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index c6234738735f..4773446a615b 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
import argparse
+import hashlib
import itertools
import json
import logging
@@ -34,7 +35,6 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfApi, create_repo
-from huggingface_hub.utils import insecure_hashlib
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
@@ -62,7 +62,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
@@ -332,12 +332,6 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -746,7 +740,6 @@ def main(args):
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
- variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -767,7 +760,7 @@ def main(args):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
- hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = (
class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
)
@@ -808,13 +801,11 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
- )
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# Adding a modifier token which is optimized ####
@@ -1238,7 +1229,6 @@ def main(args):
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -1288,7 +1278,7 @@ def main(args):
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 1700fe771d97..606cc5c6cfdd 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -16,6 +16,7 @@
import argparse
import copy
import gc
+import hashlib
import importlib
import itertools
import logging
@@ -34,7 +35,6 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, model_info, upload_folder
-from huggingface_hub.utils import insecure_hashlib
from packaging import version
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -61,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
@@ -139,7 +139,6 @@ def log_validation(
text_encoder=text_encoder,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
**pipeline_args,
)
@@ -240,13 +239,10 @@ def parse_args(input_args=None):
type=str,
default=None,
required=False,
- help="Revision of pretrained model identifier from huggingface.co/models.",
- )
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ help=(
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+ " float32 precision."
+ ),
)
parser.add_argument(
"--tokenizer_name",
@@ -300,7 +296,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
- default="dreambooth-model",
+ default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -863,7 +859,6 @@ def main(args):
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
- variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -882,7 +877,7 @@ def main(args):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
- hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
@@ -917,18 +912,18 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
if model_has_vae(args):
vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
)
else:
vae = None
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1124,7 +1119,7 @@ def compute_text_embeddings(prompt):
unet, optimizer, train_dataloader, lr_scheduler
)
- # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
@@ -1172,7 +1167,7 @@ def compute_text_embeddings(prompt):
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
- # Get the most recent checkpoint
+            # Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
@@ -1369,7 +1364,7 @@ def compute_text_embeddings(prompt):
if global_step >= args.max_train_steps:
break
- # Create the pipeline using the trained modules and save it.
+    # Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
pipeline_args = {}
@@ -1384,7 +1379,6 @@ def compute_text_embeddings(prompt):
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
- variant=args.variant,
**pipeline_args,
)
diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index 3dbb6430ea2c..4ac4f969ee69 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -1,8 +1,10 @@
import argparse
+import hashlib
import logging
import math
import os
from pathlib import Path
+from typing import Optional
import jax
import jax.numpy as jnp
@@ -14,8 +16,7 @@
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
-from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
from jax.experimental.compilation_cache import compilation_cache as cc
from PIL import Image
from torch.utils.data import Dataset
@@ -35,7 +36,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
@@ -317,6 +318,16 @@ def __getitem__(self, index):
return example
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
def get_params_to_save(params):
return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))
@@ -373,7 +384,7 @@ def main():
images = pipeline.numpy_to_pil(np.array(images))
for i, image in enumerate(images):
- hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
@@ -381,13 +392,21 @@ def main():
# Handle the repository creation
if jax.process_index() == 0:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
- ).repo_id
+ if args.hub_model_id is None:
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+ else:
+ repo_name = args.hub_model_id
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+ if "step_*" not in gitignore:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in gitignore:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer and add the placeholder token as a additional special token
if args.tokenizer_name:
@@ -460,10 +479,7 @@ def collate_fn(examples):
# Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="text_encoder",
- dtype=weight_dtype,
- revision=args.revision,
+ args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
vae_arg,
@@ -471,10 +487,7 @@ def collate_fn(examples):
**vae_kwargs,
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="unet",
- dtype=weight_dtype,
- revision=args.revision,
+ args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
)
# Optimization
@@ -655,12 +668,7 @@ def checkpoint(step=None):
if args.push_to_hub:
message = f"checkpoint-{step}" if step is not None else "End of training"
- upload_folder(
- repo_id=repo_id,
- folder_path=args.output_dir,
- commit_message=message,
- ignore_patterns=["step_*", "epoch_*"],
- )
+ repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)
global_step = 0
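
The Flax DreamBooth hunks above replace the newer `create_repo`/`upload_folder` push flow with the older `Repository`-based one. As a minimal, self-contained sketch of that older pattern (the repo name and output directory are placeholders, and a stored Hub token plus git-lfs are assumed):

```py
# Minimal sketch of the legacy Repository-based Hub push flow restored above.
# "my-dreambooth-model" and "./output" are placeholders, not values from the script.
import os
from typing import Optional

from huggingface_hub import HfFolder, Repository, create_repo, whoami


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    # Resolve "<username>/<model_id>" when no organization is given.
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    return f"{organization}/{model_id}"


output_dir = "./output"  # placeholder
repo_name = get_full_repo_name("my-dreambooth-model")
create_repo(repo_name, exist_ok=True)
repo = Repository(output_dir, clone_from=repo_name)

# Keep intermediate checkpoints out of the Hub repository.
with open(os.path.join(output_dir, ".gitignore"), "w+") as gitignore:
    gitignore.write("step_*\nepoch_*\n")

# ... training writes its artifacts into output_dir ...
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
```

The `.gitignore` filter and the non-blocking `push_to_hub` mirror what the restored hunks do inside `main()`.
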
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 80889956f221..ac72974c4a1c 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -16,6 +16,7 @@
import argparse
import copy
import gc
+import hashlib
import itertools
import logging
import math
@@ -33,7 +34,6 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
from packaging import version
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -51,7 +51,10 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
-from diffusers.loaders import LoraLoaderMixin
+from diffusers.loaders import (
+ LoraLoaderMixin,
+ text_encoder_lora_state_dict,
+)
from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
@@ -65,44 +68,11 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
-# TODO: This function should be removed once training scripts are rewritten in PEFT
-def text_encoder_lora_state_dict(text_encoder):
- state_dict = {}
-
- def text_encoder_attn_modules(text_encoder):
- from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
- attn_modules = []
-
- if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
- for i, layer in enumerate(text_encoder.text_model.encoder.layers):
- name = f"text_model.encoder.layers.{i}.self_attn"
- mod = layer.self_attn
- attn_modules.append((name, mod))
-
- return attn_modules
-
- for name, module in text_encoder_attn_modules(text_encoder):
- for k, v in module.q_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.k_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.v_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.out_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
- return state_dict
-
-
def save_model_card(
repo_id: str,
images=None,
@@ -183,12 +153,6 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -756,7 +720,6 @@ def main(args):
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
- variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -775,7 +738,7 @@ def main(args):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
-        hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+        hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
@@ -810,11 +773,11 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
try:
vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
)
except OSError:
# IF does not have a VAE so let's just set it to None
@@ -822,7 +785,7 @@ def main(args):
vae = None
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# We only train the additional adapter LoRA layers
@@ -831,7 +794,7 @@ def main(args):
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
- # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
@@ -1317,7 +1280,6 @@ def compute_text_embeddings(prompt):
unet=accelerator.unwrap_model(unet),
text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1403,7 +1365,7 @@ def compute_text_embeddings(prompt):
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
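
This file, like several others in the patch, also swaps `huggingface_hub.utils.insecure_hashlib` back to the standard-library `hashlib` when naming generated class images. A small sketch of that naming scheme, using a blank PIL image as a stand-in for a pipeline output and placeholder paths:

```py
# Sketch of the hash-based class-image naming used above; the blank RGB image
# stands in for pipeline(example["prompt"]).images[i], and the paths are placeholders.
import hashlib
from pathlib import Path

from PIL import Image

class_images_dir = Path("class_images")  # placeholder
class_images_dir.mkdir(parents=True, exist_ok=True)

cur_class_images = 0  # number of class images already on disk
index = 0  # stand-in for example["index"][i]
image = Image.new("RGB", (512, 512))  # stand-in for a generated image

# SHA-1 is only used to derive a stable filename here, not for security
# (which is why the newer scripts alias it as insecure_hashlib).
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{index + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
```
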
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 38f0860ab77e..8ef666840b3a 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -15,6 +15,7 @@
import argparse
import gc
+import hashlib
import itertools
import logging
import math
@@ -30,9 +31,8 @@
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
from packaging import version
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -49,114 +49,51 @@
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
-from diffusers.loaders import LoraLoaderMixin
+from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr, unet_lora_state_dict
+from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
-# TODO: This function should be removed once training scripts are rewritten in PEFT
-def text_encoder_lora_state_dict(text_encoder):
- state_dict = {}
-
- def text_encoder_attn_modules(text_encoder):
- from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
- attn_modules = []
-
- if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
- for i, layer in enumerate(text_encoder.text_model.encoder.layers):
- name = f"text_model.encoder.layers.{i}.self_attn"
- mod = layer.self_attn
- attn_modules.append((name, mod))
-
- return attn_modules
-
- for name, module in text_encoder_attn_modules(text_encoder):
- for k, v in module.q_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.k_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.v_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.out_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
- return state_dict
-
-
def save_model_card(
- repo_id: str,
- images=None,
- base_model=str,
- train_text_encoder=False,
- instance_prompt=str,
- validation_prompt=str,
- repo_folder=None,
- vae_path=None,
+ repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
):
- img_str = "widget:\n" if images else ""
+ img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
- img_str += f"""
- - text: '{validation_prompt if validation_prompt else ' ' }'
- output:
- url:
- "image_{i}.png"
- """
+        img_str += f"![img_{i}](image_{i}.png)\n"
yaml = f"""
---
+license: openrail++
+base_model: {base_model}
+instance_prompt: {prompt}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- lora
-- template:sd-lora
-{img_str}
-base_model: {base_model}
-instance_prompt: {instance_prompt}
-license: openrail++
+inference: true
---
"""
-
model_card = f"""
-# SDXL LoRA DreamBooth - {repo_id}
-
-
+# LoRA DreamBooth - {repo_id}
-## Model description
-
-These are {repo_id} LoRA adaption weights for {base_model}.
-
-The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
+{img_str}
LoRA for the text encoder was enabled: {train_text_encoder}.
Special VAE used for training: {vae_path}.
-
-## Trigger words
-
-You should use {instance_prompt} to trigger the image generation.
-
-## Download model
-
-Weights for this model are available in Safetensors format.
-
-[Download]({repo_id}/tree/main) them in the Files & versions tab.
-
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
@@ -204,59 +141,13 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
- parser.add_argument(
- "--dataset_name",
- type=str,
- default=None,
- help=(
- "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
- " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
- " or to a folder containing files that 🤗 Datasets can understand."
- ),
- )
- parser.add_argument(
- "--dataset_config_name",
- type=str,
- default=None,
- help="The config of the Dataset, leave as None if there's only one config.",
- )
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
- help=("A folder containing the training data. "),
- )
-
- parser.add_argument(
- "--cache_dir",
- type=str,
- default=None,
- help="The directory where the downloaded models and datasets will be stored.",
- )
-
- parser.add_argument(
- "--image_column",
- type=str,
- default="image",
- help="The column of the dataset containing the target image. By "
- "default, the standard Image Dataset maps out 'file_name' "
- "to 'image'.",
- )
- parser.add_argument(
- "--caption_column",
- type=str,
- default=None,
- help="The column of the dataset containing the instance prompt for each image",
+ required=True,
+ help="A folder containing the training data of instance images.",
)
-
- parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
-
parser.add_argument(
"--class_data_dir",
type=str,
@@ -269,7 +160,7 @@ def parse_args(input_args=None):
type=str,
default=None,
required=True,
- help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
@@ -408,16 +299,9 @@ def parse_args(input_args=None):
parser.add_argument(
"--learning_rate",
type=float,
- default=1e-4,
+ default=5e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
-
- parser.add_argument(
- "--text_encoder_lr",
- type=float,
- default=5e-6,
- help="Text encoder learning rate to use.",
- )
parser.add_argument(
"--scale_lr",
action="store_true",
@@ -433,14 +317,6 @@ def parse_args(input_args=None):
' "constant", "constant_with_warmup"]'
),
)
-
- parser.add_argument(
- "--snr_gamma",
- type=float,
- default=None,
- help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
- )
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
@@ -459,59 +335,13 @@ def parse_args(input_args=None):
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
-
parser.add_argument(
- "--optimizer",
- type=str,
- default="AdamW",
- help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
- )
-
- parser.add_argument(
- "--use_8bit_adam",
- action="store_true",
- help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
- )
-
- parser.add_argument(
- "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
- )
- parser.add_argument(
- "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
- )
- parser.add_argument(
- "--prodigy_beta3",
- type=float,
- default=None,
- help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
- "uses the value of square root of beta2. Ignored if optimizer is adamW",
- )
- parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
- parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
- parser.add_argument(
- "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
- )
-
- parser.add_argument(
- "--adam_epsilon",
- type=float,
- default=1e-08,
- help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
- )
-
- parser.add_argument(
- "--prodigy_use_bias_correction",
- type=bool,
- default=True,
- help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
- )
- parser.add_argument(
- "--prodigy_safeguard_warmup",
- type=bool,
- default=True,
- help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
- "Ignored if optimizer is adamW",
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -584,12 +414,6 @@ def parse_args(input_args=None):
else:
args = parser.parse_args()
- if args.dataset_name is None and args.instance_data_dir is None:
- raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
-
- if args.dataset_name is not None and args.instance_data_dir is not None:
- raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
-
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@@ -618,84 +442,20 @@ class DreamBoothDataset(Dataset):
def __init__(
self,
instance_data_root,
- instance_prompt,
- class_prompt,
class_data_root=None,
class_num=None,
size=1024,
- repeats=1,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
- self.instance_prompt = instance_prompt
- self.custom_instance_prompts = None
- self.class_prompt = class_prompt
-
- # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
- # we load the training data using load_dataset
- if args.dataset_name is not None:
- try:
- from datasets import load_dataset
- except ImportError:
- raise ImportError(
- "You are trying to load your data using the datasets library. If you wish to train using custom "
- "captions please install the datasets library: `pip install datasets`. If you wish to load a "
- "local folder containing images only, specify --instance_data_dir instead."
- )
- # Downloading and loading a dataset from the hub.
- # See more about loading custom images at
- # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
- dataset = load_dataset(
- args.dataset_name,
- args.dataset_config_name,
- cache_dir=args.cache_dir,
- )
- # Preprocessing the datasets.
- column_names = dataset["train"].column_names
-
- # 6. Get the column names for input/target.
- if args.image_column is None:
- image_column = column_names[0]
- logger.info(f"image column defaulting to {image_column}")
- else:
- image_column = args.image_column
- if image_column not in column_names:
- raise ValueError(
- f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
- )
- instance_images = dataset["train"][image_column]
-
- if args.caption_column is None:
- logger.info(
- "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
- "contains captions/prompts for the images, make sure to specify the "
- "column as --caption_column"
- )
- self.custom_instance_prompts = None
- else:
- if args.caption_column not in column_names:
- raise ValueError(
- f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
- )
- custom_instance_prompts = dataset["train"][args.caption_column]
- # create final list of captions according to --repeats
- self.custom_instance_prompts = []
- for caption in custom_instance_prompts:
- self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
- else:
- self.instance_data_root = Path(instance_data_root)
- if not self.instance_data_root.exists():
- raise ValueError("Instance images root doesn't exists.")
-
- instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
- self.custom_instance_prompts = None
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+            raise ValueError("Instance images root doesn't exist.")
- self.instance_images = []
- for img in instance_images:
- self.instance_images.extend(itertools.repeat(img, repeats))
- self.num_instance_images = len(self.instance_images)
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
self._length = self.num_instance_images
if class_data_root is not None:
@@ -724,23 +484,13 @@ def __len__(self):
def __getitem__(self, index):
example = {}
- instance_image = self.instance_images[index % self.num_instance_images]
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
instance_image = exif_transpose(instance_image)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
- if self.custom_instance_prompts:
- caption = self.custom_instance_prompts[index % self.num_instance_images]
- if caption:
- example["instance_prompt"] = caption
- else:
- example["instance_prompt"] = self.instance_prompt
-
- else: # costum prompts were provided, but length does not match size of image dataset
- example["instance_prompt"] = self.instance_prompt
-
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
class_image = exif_transpose(class_image)
@@ -748,25 +498,22 @@ def __getitem__(self, index):
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
- example["class_prompt"] = self.class_prompt
return example
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
- prompts = [example["instance_prompt"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if with_prior_preservation:
pixel_values += [example["class_images"] for example in examples]
- prompts += [example["class_prompt"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
- batch = {"pixel_values": pixel_values, "prompts": prompts}
+ batch = {"pixel_values": pixel_values}
return batch
@@ -832,13 +579,12 @@ def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
- kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
- kwargs_handlers=[kwargs],
)
if args.report_to == "wandb":
@@ -883,7 +629,6 @@ def main(args):
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
revision=args.revision,
- variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -902,7 +647,7 @@ def main(args):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
- hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
@@ -922,16 +667,10 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer_2",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
)
# import correct text encoder classes
@@ -945,10 +684,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
vae_path = (
args.pretrained_model_name_or_path
@@ -956,13 +695,10 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
- vae_path,
- subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
- revision=args.revision,
- variant=args.variant,
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# We only train the additional adapter LoRA layers
@@ -971,7 +707,7 @@ def main(args):
text_encoder_two.requires_grad_(False)
unet.requires_grad_(False)
- # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
@@ -995,8 +731,7 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
- "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
- "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
@@ -1130,119 +865,35 @@ def load_model_hook(models, input_dir):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
- # Optimization parameters
- unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
- if args.train_text_encoder:
- # different learning rate for text encoder and unet
- text_lora_parameters_one_with_lr = {
- "params": text_lora_parameters_one,
- "weight_decay": args.adam_weight_decay_text_encoder,
- "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
- }
- text_lora_parameters_two_with_lr = {
- "params": text_lora_parameters_two,
- "weight_decay": args.adam_weight_decay_text_encoder,
- "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
- }
- params_to_optimize = [
- unet_lora_parameters_with_lr,
- text_lora_parameters_one_with_lr,
- text_lora_parameters_two_with_lr,
- ]
- else:
- params_to_optimize = [unet_lora_parameters_with_lr]
-
- # Optimizer creation
- if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
- logger.warn(
- f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
- "Defaulting to adamW"
- )
- args.optimizer = "adamw"
-
- if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
- logger.warn(
- f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
- f"set to {args.optimizer.lower()}"
- )
-
- if args.optimizer.lower() == "adamw":
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError(
- "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
- )
-
- optimizer_class = bnb.optim.AdamW8bit
- else:
- optimizer_class = torch.optim.AdamW
-
- optimizer = optimizer_class(
- params_to_optimize,
- betas=(args.adam_beta1, args.adam_beta2),
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- )
-
- if args.optimizer.lower() == "prodigy":
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
try:
- import prodigyopt
+ import bitsandbytes as bnb
except ImportError:
- raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
-
- optimizer_class = prodigyopt.Prodigy
-
- if args.learning_rate <= 0.1:
- logger.warn(
- "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
- )
- if args.train_text_encoder and args.text_encoder_lr:
- logger.warn(
- f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
- f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
- f"When using prodigy only learning_rate is used as the initial learning rate."
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
- # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
- # --learning_rate
- params_to_optimize[1]["lr"] = args.learning_rate
- params_to_optimize[2]["lr"] = args.learning_rate
-
- optimizer = optimizer_class(
- params_to_optimize,
- lr=args.learning_rate,
- betas=(args.adam_beta1, args.adam_beta2),
- beta3=args.prodigy_beta3,
- weight_decay=args.adam_weight_decay,
- eps=args.adam_epsilon,
- decouple=args.prodigy_decouple,
- use_bias_correction=args.prodigy_use_bias_correction,
- safeguard_warmup=args.prodigy_safeguard_warmup,
- )
- # Dataset and DataLoaders creation:
- train_dataset = DreamBoothDataset(
- instance_data_root=args.instance_data_dir,
- instance_prompt=args.instance_prompt,
- class_prompt=args.class_prompt,
- class_data_root=args.class_data_dir if args.with_prior_preservation else None,
- class_num=args.num_class_images,
- size=args.resolution,
- repeats=args.repeats,
- center_crop=args.center_crop,
- )
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset,
- batch_size=args.train_batch_size,
- shuffle=True,
- collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
- num_workers=args.dataloader_num_workers,
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+ if args.train_text_encoder
+ else unet_lora_parameters
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
)
# Computes additional embeddings/ids required by the SDXL UNet.
- # regular text embeddings (when `train_text_encoder` is not True)
+    # regular text embeddings (when `train_text_encoder` is not True)
# pooled text embeddings
# time ids
@@ -1269,11 +920,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# Handle instance prompt.
instance_time_ids = compute_time_ids()
-
- # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
- # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
- # the redundant encoding.
- if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
args.instance_prompt, text_encoders, tokenizers
)
@@ -1286,36 +933,49 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
args.class_prompt, text_encoders, tokenizers
)
- # Clear the memory here
- if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ # Clear the memory here.
+ if not args.train_text_encoder:
del tokenizers, text_encoders
gc.collect()
torch.cuda.empty_cache()
- # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
- # pack the statically computed variables appropriately here. This is so that we don't
+ # Pack the statically computed variables appropriately. This is so that we don't
# have to pass them to the dataloader.
add_time_ids = instance_time_ids
if args.with_prior_preservation:
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
- if not train_dataset.custom_instance_prompts:
- if not args.train_text_encoder:
- prompt_embeds = instance_prompt_hidden_states
- unet_add_text_embeds = instance_pooled_prompt_embeds
- if args.with_prior_preservation:
- prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
- unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
- # if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
- # batch prompts on all training steps
- else:
- tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
- tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
- if args.with_prior_preservation:
- class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
- class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
- tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
- tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ unet_add_text_embeds = instance_pooled_prompt_embeds
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -1410,25 +1070,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if args.train_text_encoder:
text_encoder_one.train()
text_encoder_two.train()
-
- # set top parameter requires_grad = True for gradient checkpointing works
- text_encoder_one.text_model.embeddings.requires_grad_(True)
- text_encoder_two.text_model.embeddings.requires_grad_(True)
-
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
- prompts = batch["prompts"]
-
- # encode batch prompts when custom prompts are provided for each image -
- if train_dataset.custom_instance_prompts:
- if not args.train_text_encoder:
- prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
- prompts, text_encoders, tokenizers
- )
- else:
- tokens_one = tokenize_prompt(tokenizer_one, prompts)
- tokens_two = tokenize_prompt(tokenizer_two, prompts)
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
@@ -1449,21 +1093,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
- # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
- if not train_dataset.custom_instance_prompts:
- elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
- elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
- else:
- elems_to_repeat_text_embeds = 1
- elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
+ # Calculate the elements to repeat depending on the use of prior-preservation.
+ elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
# Predict the noise residual
if not args.train_text_encoder:
unet_added_conditions = {
- "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
- "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
+ "time_ids": add_time_ids.repeat(elems_to_repeat, 1),
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
}
- prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet(
noisy_model_input,
timesteps,
@@ -1471,17 +1110,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
added_cond_kwargs=unet_added_conditions,
).sample
else:
- unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
+ unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=None,
prompt=None,
text_input_ids_list=[tokens_one, tokens_two],
)
- unet_added_conditions.update(
- {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
- )
- prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
).sample
@@ -1499,34 +1136,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
- # Compute prior loss
- prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
- if args.snr_gamma is None:
+ # Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
- else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
- # Since we predict the noise instead of x_0, the original formulation is slightly changed.
- # This is discussed in Section 4.2 of the same paper.
- snr = compute_snr(noise_scheduler, timesteps)
- base_weight = (
- torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
- )
- if noise_scheduler.config.prediction_type == "v_prediction":
- # Velocity objective needs to be floored to an SNR weight of one.
- mse_loss_weights = base_weight + 1
- else:
- # Epsilon and sample both use the same loss weights.
- mse_loss_weights = base_weight
-
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
- loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
- loss = loss.mean()
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
- if args.with_prior_preservation:
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -1587,16 +1206,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# create pipeline
if not args.train_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="text_encoder",
- revision=args.revision,
- variant=args.variant,
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="text_encoder_2",
- revision=args.revision,
- variant=args.variant,
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1605,7 +1218,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1683,15 +1295,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- vae=vae,
- revision=args.revision,
- variant=args.variant,
- torch_dtype=weight_dtype,
+ args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
@@ -1740,8 +1347,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
- instance_prompt=args.instance_prompt,
- validation_prompt=args.validation_prompt,
+ prompt=args.instance_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
)
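
The loss hunk in this file drops the SNR-weighted objective and restores the plain prior-preservation MSE. Below is a rough sketch of that restored computation on random tensors; the batch layout (instance half followed by class half) and the weight of 1.0 are illustrative assumptions.

```py
# Sketch of the prior-preservation loss restored above, evaluated on dummy tensors.
import torch
import torch.nn.functional as F

prior_loss_weight = 1.0        # illustrative value
with_prior_preservation = True

# Pretend batch: the first half are instance samples, the second half class (prior) samples.
model_pred = torch.randn(4, 4, 64, 64)
target = torch.randn(4, 4, 64, 64)

if with_prior_preservation:
    # Split the concatenated batch back into instance and prior halves.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    # Instance loss plus weighted prior loss.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    loss = loss + prior_loss_weight * prior_loss
else:
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

print(loss.item())
```
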
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index 8391713068df..5f8a2d9ee150 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -78,12 +78,6 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--dataset_name",
type=str,
@@ -441,11 +435,9 @@ def main():
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
- )
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
@@ -923,7 +915,6 @@ def collate_fn(examples):
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -975,7 +966,6 @@ def collate_fn(examples):
vae=accelerator.unwrap_model(vae),
unet=unet,
revision=args.revision,
- variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index 89fdf2cde9c2..e2d9b2105160 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -55,7 +55,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -118,12 +118,6 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--dataset_name",
type=str,
@@ -490,10 +484,9 @@ def main():
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
- variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
@@ -702,16 +695,10 @@ def preprocess_images(examples):
# Load scheduler, tokenizer and models.
tokenizer_1 = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
)
tokenizer_2 = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer_2",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
)
text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
text_encoder_cls_2 = import_model_class_from_model_name_or_path(
@@ -721,10 +708,10 @@ def preprocess_images(examples):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_1 = text_encoder_cls_1.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_2 = text_encoder_cls_2.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
# We ALWAYS pre-compute the additional condition embeddings needed for SDXL
@@ -1122,7 +1109,6 @@ def collate_fn(examples):
tokenizer_2=tokenizer_2,
vae=vae,
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -1190,7 +1176,6 @@ def collate_fn(examples):
vae=vae,
unet=unet,
revision=args.revision,
- variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
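
Across these files the patch also removes the `--variant` argument, so sub-models are loaded with `revision` alone. The snippet below contrasts the two loading styles; the SDXL base checkpoint is only an example, and it is assumed to ship an `fp16` weight variant.

```py
# Contrast of the two loading styles touched by this patch. Downloading the
# example checkpoint requires network access; the fp16 variant is assumed to exist.
from diffusers import UNet2DConditionModel

model_id = "stabilityai/stable-diffusion-xl-base-1.0"  # example checkpoint only

# Newer style (removed by the patch): request a specific weight variant such as fp16.
unet_fp16 = UNet2DConditionModel.from_pretrained(
    model_id, subfolder="unet", revision=None, variant="fp16"
)

# Older style (restored by the patch): revision only, default-precision weights.
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision=None)
```
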
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index 5efa3c9eeca9..4ca95ecebea9 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.21.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 49bb4d0a4fd8..19245724ecf5 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.21.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index f8f44e9c8953..7305137218ef 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.21.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -682,7 +682,7 @@ def collate_fn(examples):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
- accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
+ accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
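
The one-line change above clips gradients on `prior.parameters()` rather than `lora_layers.parameters()`. Below is a toy sketch of that clipping pattern with 🤗 Accelerate; the linear layer merely stands in for the Kandinsky prior, and all names and sizes are made up.

```py
# Toy sketch of gradient clipping inside an Accelerate training step.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
prior = torch.nn.Linear(8, 8)                      # stand-in for the prior model
optimizer = torch.optim.AdamW(prior.parameters(), lr=1e-4)
max_grad_norm = 1.0

prior, optimizer = accelerator.prepare(prior, optimizer)

x = torch.randn(4, 8, device=accelerator.device)
loss = prior(x).pow(2).mean()

accelerator.backward(loss)
if accelerator.sync_gradients:
    # Clip the gradients of the full module's parameters before the optimizer step.
    accelerator.clip_grad_norm_(prior.parameters(), max_grad_norm)
optimizer.step()
optimizer.zero_grad()
```
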
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index 55873ae6da7a..d21eaf3dd0b0 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.21.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/research_projects/colossalai/README.md b/examples/research_projects/colossalai/README.md
index be94950b772e..7c428bbce736 100644
--- a/examples/research_projects/colossalai/README.md
+++ b/examples/research_projects/colossalai/README.md
@@ -41,7 +41,7 @@ The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Mode
## Training
-The argument `placement` can be `cpu`, `auto`, `cuda`, with `cpu` the GPU RAM required can be minimized to 4GB but will deceleration, with `cuda` you can also reduce GPU memory by half but accelerated training, with `auto` a more balanced solution for speed and memory can be obtained。
+The argument `placement` can be `cpu`, `auto`, or `cuda`: with `cpu` the required GPU RAM can be reduced to 4GB but training slows down, with `cuda` GPU memory is roughly halved while training stays fast, and `auto` picks a more balanced trade-off between speed and memory.
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
diff --git a/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
index 5cebd2b81175..3d4466bf94b7 100644
--- a/examples/research_projects/colossalai/train_dreambooth_colossalai.py
+++ b/examples/research_projects/colossalai/train_dreambooth_colossalai.py
@@ -1,4 +1,5 @@
import argparse
+import hashlib
import math
import os
from pathlib import Path
@@ -15,7 +16,6 @@
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
@@ -394,7 +394,7 @@ def main(args):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
- hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
@@ -464,7 +464,9 @@ def main(args):
unet = gemini_zero_dpp(unet, args.placement)
# config optimizer for colossalai zero
- optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+ optimizer = GeminiAdamOptimizer(
+ unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
+ )
# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
index 0e82a45c024f..a3eaba014cf6 100644
--- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
+++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
@@ -1,4 +1,5 @@
import argparse
+import hashlib
import itertools
import math
import os
@@ -13,7 +14,6 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
from PIL import Image, ImageDraw
from torch.utils.data import Dataset
from torchvision import transforms
@@ -465,7 +465,7 @@ def main():
images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
for i, image in enumerate(images):
- hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py
index 3d79b2ceadaf..d25c6d22f8e7 100644
--- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py
+++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py
@@ -1,4 +1,5 @@
import argparse
+import hashlib
import math
import os
import random
@@ -12,7 +13,6 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
from PIL import Image, ImageDraw
from torch.utils.data import Dataset
from torchvision import transforms
@@ -464,7 +464,7 @@ def main():
images = pipeline(prompt=example["prompt"], mask_image=fake_mask, image=fake_pil_images).images
for i, image in enumerate(images):
- hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
diff --git a/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py b/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py
index 43667187596e..b19dd6e1103d 100644
--- a/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py
+++ b/examples/research_projects/intel_opts/textual_inversion_dfq/textual_inversion.py
@@ -4,7 +4,7 @@
import os
import random
from pathlib import Path
-from typing import Iterable
+from typing import Iterable, Optional
import numpy as np
import PIL
@@ -13,7 +13,7 @@
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import HfFolder, Repository, whoami
from neural_compressor.utils import logger
from packaging import version
from PIL import Image
@@ -413,6 +413,16 @@ def __getitem__(self, i):
return example
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
def freeze_params(params):
for param in params:
param.requires_grad = False
@@ -451,13 +461,20 @@ def main():
# Handle the repository creation
if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
- ).repo_id
+ if args.hub_model_id is None:
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+ else:
+ repo_name = args.hub_model_id
+ repo = Repository(args.output_dir, clone_from=repo_name)
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+ if "step_*" not in gitignore:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in gitignore:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer and add the placeholder token as a additional special token
if args.tokenizer_name:
@@ -965,12 +982,7 @@ def attention_fetcher(x):
)
if args.push_to_hub:
- upload_folder(
- repo_id=repo_id,
- folder_path=args.output_dir,
- commit_message="End of training",
- ignore_patterns=["step_*", "epoch_*"],
- )
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
accelerator.end_training()
diff --git a/examples/research_projects/multi_subject_dreambooth/README.md b/examples/research_projects/multi_subject_dreambooth/README.md
index 5fff305f82be..d1a7705cfebb 100644
--- a/examples/research_projects/multi_subject_dreambooth/README.md
+++ b/examples/research_projects/multi_subject_dreambooth/README.md
@@ -323,7 +323,7 @@ accelerate launch train_dreambooth.py \
### Using DreamBooth for other pipelines than Stable Diffusion
-Altdiffusion also support dreambooth now, the runing comman is basically the same as above, all you need to do is replace the `MODEL_NAME` like this:
+Altdiffusion also support dreambooth now, the runing comman is basically the same as abouve, all you need to do is replace the `MODEL_NAME` like this:
One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion).
```
diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
index d58c4009b69a..4e03e23fc128 100644
--- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
+++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -1,4 +1,5 @@
import argparse
+import hashlib
import itertools
import json
import logging
@@ -20,7 +21,6 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
-from huggingface_hub.utils import insecure_hashlib
from PIL import Image
from torch import dtype
from torch.nn import Module
@@ -843,7 +843,7 @@ def main(args):
images = pipeline(example["prompt"]).images
for ii, image in enumerate(images):
- hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = (
class_images_dir / f"{example['index'][ii] + cur_class_images}-{hash_image}.jpg"
)
diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index 5cad9f2fbed9..ba5ccd238fdc 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -4,6 +4,7 @@
import math
import os
from pathlib import Path
+from typing import Optional
import accelerate
import datasets
@@ -13,7 +14,7 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
from onnxruntime.training.ortmodule import ORTModule
from packaging import version
@@ -276,6 +277,16 @@ def parse_args():
return args
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(
@@ -349,13 +360,21 @@ def load_model_hook(models, input_dir):
# Handle the repository creation
if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
- ).repo_id
+ if args.hub_model_id is None:
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+ else:
+ repo_name = args.hub_model_id
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+ if "step_*" not in gitignore:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in gitignore:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
# Initialize the model
if args.model_config_name_or_path is None:
@@ -672,12 +691,7 @@ def transform_images(examples):
ema_model.restore(unet.parameters())
if args.push_to_hub:
- upload_folder(
- repo_id=repo_id,
- folder_path=args.output_dir,
- commit_message=f"Epoch {epoch}",
- ignore_patterns=["step_*", "epoch_*"],
- )
+ repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
accelerator.end_training()
diff --git a/examples/research_projects/sdxl_flax/README.md b/examples/research_projects/sdxl_flax/README.md
index 612fdf1edd43..fca21912982a 100644
--- a/examples/research_projects/sdxl_flax/README.md
+++ b/examples/research_projects/sdxl_flax/README.md
@@ -151,7 +151,7 @@ telling JAX which input arguments are static, that is, arguments that
are known at compile time and won't change. In our case, it is num_inference_steps,
height, width and return_latents.
-Once the function is compiled, these parameters are omitted from future calls and
+Once the function is compiled, these parameters are ommited from future calls and
cannot be changed without modifying the code and recompiling.
```python
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index 4b1687cbeba0..e23be2d754fe 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -58,7 +58,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
@@ -85,7 +85,6 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
unet=unet,
adapter=adapter,
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -263,12 +262,6 @@ def parse_args(input_args=None):
" float32 precision."
),
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -819,16 +812,10 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer_2",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
)
# import correct text encoder classes
@@ -842,10 +829,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
vae_path = (
args.pretrained_model_name_or_path
@@ -856,10 +843,9 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
- variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
if args.adapter_model_name_or_path:
diff --git a/examples/test_examples.py b/examples/test_examples.py
new file mode 100644
index 000000000000..89e866231e89
--- /dev/null
+++ b/examples/test_examples.py
@@ -0,0 +1,1682 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+import unittest
+from typing import List
+
+import safetensors
+from accelerate.utils import write_basic_config
+
+from diffusers import DiffusionPipeline, UNet2DConditionModel
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+ pass
+
+
+def run_command(command: List[str], return_stdout=False):
+ """
+ Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly
+ capture and report the error output if an error occurred while running `command`
+ """
+ try:
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ if return_stdout:
+ if hasattr(output, "decode"):
+ output = output.decode("utf-8")
+ return output
+ except subprocess.CalledProcessError as e:
+ raise SubprocessCallException(
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+ ) from e
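+
+# Illustrative call (hypothetical script path and flags), mirroring how the tests below use it:
+#   run_command(["accelerate", "launch", "examples/some_example/train_script.py", "--max_train_steps", "2"])
+# A non-zero exit code surfaces the captured stdout/stderr as a SubprocessCallException.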
+
+
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class ExamplesTestsAccelerate(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super().setUpClass()
+ cls._tmpdir = tempfile.mkdtemp()
+ cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
+
+ write_basic_config(save_location=cls.configPath)
+ cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ shutil.rmtree(cls._tmpdir)
+
+ def test_train_unconditional(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 2
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ """.split()
+
+ run_command(self._launch_args + test_args, return_stdout=True)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_textual_inversion(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <cat-toy>
+ --initializer_token a
+ --validation_prompt <cat-toy>
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
+
+ def test_dreambooth(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_if(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_checkpointing(self):
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 5, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
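+ # (A checkpoint is written every `checkpointing_steps` optimizer steps, so with 5 total
+ # steps the expected directories are checkpoint-2 and checkpoint-4 only.)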
+
+ initial_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 5
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check can run the original fully trained output pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=2)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
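+ # (Resuming from checkpoint-4 with max_train_steps=7 runs three more steps, so one new
+ # checkpoint-6 is expected alongside the surviving checkpoint-4.)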
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt {instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(instance_prompt, num_inference_steps=2)
+
+ # check old checkpoints do not exist
+ self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+ # check new checkpoints exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+ def test_dreambooth_lora(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
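+ # (Illustrative key shape: something like "unet.<attention module path>.lora.<up/down>.weight";
+ # the checks above and below only rely on the "unet" prefix and the "lora" substring.)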
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --train_text_encoder
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # check `text_encoder` is present at all.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ keys = lora_state_dict.keys()
+ is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_text_encoder_present)
+
+ # the names of the keys of the state dict should either start with `unet`
+ # or `text_encoder`.
+ is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
+ self.assertTrue(is_correct_naming)
+
+ def test_dreambooth_lora_if_model(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --pre_compute_text_embeddings
+ --tokenizer_max_length=77
+ --text_encoder_use_attention_mask
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"unet"` in their names.
+ starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when training the text encoder as well, every parameter in the state dict should start
+ # with `"unet"`, `"text_encoder"`, or `"text_encoder_2"` in its name.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
+
+ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt photo
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --train_text_encoder
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe("a prompt", num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_custom_diffusion(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir docs/source/en/imgs
+ --instance_prompt <new1>
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 1.0e-05
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --modifier_token <new1>
+ --no_safe_serialization
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "<new1>.bin")))
+
+ def test_text_to_image(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_text_to_image_checkpointing(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 5, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 5
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ "checkpoint-4",
+ "checkpoint-6",
+ },
+ )
+
+ def test_text_to_image_checkpointing_use_ema(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 5, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 5
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ # check can run an intermediate checkpoint
+ unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --use_ema
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {
+ # no checkpoint-2 -> check old checkpoints do not exist
+ # check new checkpoints exist
+ "checkpoint-4",
+ "checkpoint-6",
+ },
+ )
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
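+ # (With checkpoints_total_limit=2 only the two most recent checkpoints are kept, so once
+ # checkpoint-6 is written the oldest one, checkpoint-2, is pruned, leaving checkpoint-4
+ # and checkpoint-6.)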
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 9, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4, 6, 8
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 9
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ # resume and we should try to checkpoint at 10, where we'll have to remove
+ # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
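+ # (With checkpoints_total_limit=3, writing checkpoint-10 would leave five checkpoints, so
+ # the two oldest, checkpoint-2 and checkpoint-4, are removed to keep only checkpoint-6,
+ # checkpoint-8 and checkpoint-10.)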
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 11
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_text_to_image_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+ prompt = "a prompt"
+ pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+ # Should create checkpoints at steps 2, 4, 6
+ # with checkpoint at step 2 deleted
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path {pipeline_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 7
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --train_text_encoder
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
+ prompt = "a prompt"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 9, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4, 6, 8
+
+ initial_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 9
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ # resume and we should try to checkpoint at 10, where we'll have to remove
+ # checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
+
+ resume_run_args = f"""
+ examples/text_to_image/train_text_to_image_lora.py
+ --pretrained_model_name_or_path {pretrained_model_name_or_path}
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --center_crop
+ --random_flip
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 11
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ --seed=0
+ --num_validation_images=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
+ )
+ pipe.load_lora_weights(tmpdir)
+ pipe(prompt, num_inference_steps=2)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_unconditional_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ # checkpoint-2 should have been deleted
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ initial_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 1
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
+ )
+
+ resume_run_args = f"""
+ examples/unconditional_image_generation/train_unconditional.py
+ --dataset_name hf-internal-testing/dummy_image_class_data
+ --model_config_name_or_path diffusers/ddpm_dummy
+ --resolution 64
+ --output_dir {tmpdir}
+ --train_batch_size 1
+ --num_epochs 2
+ --gradient_accumulation_steps 1
+ --ddpm_num_inference_steps 2
+ --learning_rate 1e-3
+ --lr_warmup_steps 5
+ --resume_from_checkpoint=checkpoint-6
+ --checkpointing_steps=2
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
+ )
+
+ def test_textual_inversion_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <cat-toy>
+ --initializer_token a
+ --validation_prompt <cat-toy>
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-3"},
+ )
+
+ def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <cat-toy>
+ --initializer_token a
+ --validation_prompt <cat-toy>
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 3
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-1", "checkpoint-2", "checkpoint-3"},
+ )
+
+ resume_run_args = f"""
+ examples/textual_inversion/textual_inversion.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+ --train_data_dir docs/source/en/imgs
+ --learnable_property object
+ --placeholder_token <cat-toy>
+ --initializer_token a
+ --validation_prompt <cat-toy>
+ --validation_steps 1
+ --save_steps 1
+ --num_vectors 2
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=1
+ --resume_from_checkpoint=checkpoint-3
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-3", "checkpoint-4"},
+ )
+
+ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=7
+ --checkpointing_steps=2
+ --checkpoints_total_limit=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/instruct_pix2pix/train_instruct_pix2pix.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/instructpix2pix-10-samples
+ --resolution=64
+ --random_flip
+ --train_batch_size=1
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --output_dir {tmpdir}
+ --seed=0
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check checkpoint directories exist
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/dreambooth/train_dreambooth_lora.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=prompt
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_controlnet_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/controlnet/train_controlnet.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-8", "checkpoint-10", "checkpoint-12"},
+ )
+
+ def test_controlnet_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/controlnet/train_controlnet_sdxl.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+
+ def test_t2i_adapter_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/t2i_adapter/train_t2i_adapter_sdxl.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --adapter_model_name_or_path=hf-internal-testing/tiny-adapter
+ --dataset_name=hf-internal-testing/fill10
+ --output_dir={tmpdir}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=9
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+
+ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=<new1>
+ --resolution=64
+ --train_batch_size=1
+ --modifier_token=<new1>
+ --dataloader_num_workers=0
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=<new1>
+ --resolution=64
+ --train_batch_size=1
+ --modifier_token=<new1>
+ --dataloader_num_workers=0
+ --max_train_steps=9
+ --checkpointing_steps=2
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
+ )
+
+ resume_run_args = f"""
+ examples/custom_diffusion/train_custom_diffusion.py
+ --pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
+ --instance_data_dir=docs/source/en/imgs
+ --output_dir={tmpdir}
+ --instance_prompt=<new1>
+ --resolution=64
+ --train_batch_size=1
+ --modifier_token=<new1>
+ --dataloader_num_workers=0
+ --max_train_steps=11
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-8
+ --checkpoints_total_limit=3
+ --no_safe_serialization
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
+ )
+
+ def test_text_to_image_lora_sdxl(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ def test_text_to_image_lora_sdxl_with_text_encoder(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ examples/text_to_image/train_text_to_image_lora_sdxl.py
+ --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+ --dataset_name hf-internal-testing/dummy_image_text_data
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --train_text_encoder
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when training the text encoder, all the parameters in the state dict should start
+ # with `"unet"`, `"text_encoder"`, or `"text_encoder_2"` in their names.
+ keys = lora_state_dict.keys()
+ starts_with_unet = all(
+ k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
+ )
+ self.assertTrue(starts_with_unet)
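
These new tests all exercise the same checkpoint-retention rule. As a quick reference, here is a minimal sketch (an illustration that ignores resume and dataloader-length effects, not the trainers' actual code) of how `--checkpointing_steps` and `--checkpoints_total_limit` combine to produce the expected directory sets:

```py
# Hypothetical helper, for illustration only: checkpoints are written every
# `checkpointing_steps` steps and, when a total limit is set, the oldest are removed first.
def surviving_checkpoints(max_train_steps, checkpointing_steps, checkpoints_total_limit=None):
    steps = list(range(checkpointing_steps, max_train_steps + 1, checkpointing_steps))
    if checkpoints_total_limit is not None:
        steps = steps[-checkpoints_total_limit:]
    return {f"checkpoint-{s}" for s in steps}

assert surviving_checkpoints(6, 2, checkpoints_total_limit=2) == {"checkpoint-4", "checkpoint-6"}
assert surviving_checkpoints(9, 2) == {"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"}
```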
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 3bb5fc367c4d..e216529b2f54 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -53,7 +53,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -148,7 +148,6 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
unet=accelerator.unwrap_model(unet),
safety_checker=None,
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -210,12 +209,6 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--dataset_name",
type=str,
@@ -574,10 +567,10 @@ def deepspeed_zero_init_disabled_context_manager():
# across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
text_encoder = CLIPTextModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
)
unet = UNet2DConditionModel.from_pretrained(
@@ -592,7 +585,7 @@ def deepspeed_zero_init_disabled_context_manager():
# Create EMA for the unet.
if args.use_ema:
ema_unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
@@ -1033,7 +1026,6 @@ def collate_fn(examples):
vae=vae,
unet=unet,
revision=args.revision,
- variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
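
For context on the `--variant` argument removed in the hunks above: it only forwards a weight-filename variant to `from_pretrained`. A minimal sketch of the call it enables (model id illustrative; not every checkpoint publishes an fp16 variant):

```py
import torch
from diffusers import DiffusionPipeline

# `variant="fp16"` selects e.g. `diffusion_pytorch_model.fp16.safetensors` when the repo provides it.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
)
```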
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index 64692ea3fcab..ac3afcbaba12 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -33,7 +33,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = logging.getLogger(__name__)
@@ -54,12 +54,6 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--dataset_name",
type=str,
@@ -214,12 +208,6 @@ def parse_args():
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
- parser.add_argument(
- "--from_pt",
- action="store_true",
- default=False,
- help="Flag to indicate whether to convert models from PyTorch.",
- )
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -278,7 +266,9 @@ def main():
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
- args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
)
else:
data_files = {}
@@ -384,31 +374,16 @@ def collate_fn(examples):
# Load models and create wrapper for stable diffusion
tokenizer = CLIPTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- from_pt=args.from_pt,
- revision=args.revision,
- subfolder="tokenizer",
+ args.pretrained_model_name_or_path, revision=args.revision, subfolder="tokenizer"
)
text_encoder = FlaxCLIPTextModel.from_pretrained(
- args.pretrained_model_name_or_path,
- from_pt=args.from_pt,
- revision=args.revision,
- subfolder="text_encoder",
- dtype=weight_dtype,
+ args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", dtype=weight_dtype
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path,
- from_pt=args.from_pt,
- revision=args.revision,
- subfolder="vae",
- dtype=weight_dtype,
+ args.pretrained_model_name_or_path, revision=args.revision, subfolder="vae", dtype=weight_dtype
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path,
- from_pt=args.from_pt,
- revision=args.revision,
- subfolder="unet",
- dtype=weight_dtype,
+ args.pretrained_model_name_or_path, revision=args.revision, subfolder="unet", dtype=weight_dtype
)
# Optimization
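
Similarly, the removed `--from_pt` flag only forwarded `from_pt=True` to the Flax `from_pretrained` calls so that PyTorch weights could be converted on the fly. A minimal sketch of the equivalent direct call (model id illustrative):

```py
from diffusers import FlaxUNet2DConditionModel

# `from_pt=True` converts PyTorch weights when no Flax weights are published for the repo.
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", from_pt=True
)
```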
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 86a06831bd34..eac0f18f49f4 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -40,7 +40,8 @@
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
-from diffusers.models.lora import LoRALinearLayer
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -48,44 +49,11 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__, log_level="INFO")
-# TODO: This function should be removed once training scripts are rewritten in PEFT
-def text_encoder_lora_state_dict(text_encoder):
- state_dict = {}
-
- def text_encoder_attn_modules(text_encoder):
- from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
- attn_modules = []
-
- if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
- for i, layer in enumerate(text_encoder.text_model.encoder.layers):
- name = f"text_model.encoder.layers.{i}.self_attn"
- mod = layer.self_attn
- attn_modules.append((name, mod))
-
- return attn_modules
-
- for name, module in text_encoder_attn_modules(text_encoder):
- for k, v in module.q_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.k_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.v_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.out_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
- return state_dict
-
-
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
@@ -130,12 +98,6 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--dataset_name",
type=str,
@@ -460,11 +422,9 @@ def main():
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
- )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
@@ -498,43 +458,25 @@ def main():
# => 32 layers
# Set correct lora layers
- unet_lora_parameters = []
- for attn_processor_name, attn_processor in unet.attn_processors.items():
- # Parse the attention module.
- attn_module = unet
- for n in attn_processor_name.split(".")[:-1]:
- attn_module = getattr(attn_module, n)
-
- # Set the `lora_layer` attribute of the attention-related matrices.
- attn_module.to_q.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
- )
- )
- attn_module.to_k.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
- )
+ lora_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRAAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ rank=args.rank,
)
- attn_module.to_v.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
- )
- )
- attn_module.to_out[0].set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_out[0].in_features,
- out_features=attn_module.to_out[0].out_features,
- rank=args.rank,
- )
- )
-
- # Accumulate the LoRA params to optimize.
- unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
- unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
- unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
- unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+ unet.set_attn_processor(lora_attn_procs)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -549,6 +491,8 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
+ lora_layers = AttnProcsLayers(unet.attn_processors)
+
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
@@ -573,7 +517,7 @@ def main():
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
- unet_lora_parameters,
+ lora_layers.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -700,8 +644,8 @@ def collate_fn(examples):
)
# Prepare everything with our `accelerator`.
- unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ lora_layers, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -833,7 +777,7 @@ def collate_fn(examples):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
- params_to_clip = unet_lora_parameters
+ params_to_clip = lora_layers.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -889,7 +833,6 @@ def collate_fn(examples):
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -946,7 +889,7 @@ def collate_fn(examples):
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
)
pipeline = pipeline.to(accelerator.device)
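
With this script back on the `LoRAAttnProcessor`/`AttnProcsLayers` path, the trained adapter is stored as attention processors; a minimal sketch of reloading it for inference (model id, output path, and prompt are illustrative):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Load the LoRA attention processors written to the training run's output directory.
pipe.unet.load_attn_procs("path/to/output_dir")
image = pipe("a photo in the fine-tuned style", num_inference_steps=25).images[0]
```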
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index e364ce65734d..249b9d1a9ab5 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -33,7 +33,7 @@
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
@@ -49,7 +49,7 @@
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
-from diffusers.loaders import LoraLoaderMixin
+from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
@@ -58,44 +58,11 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
-# TODO: This function should be removed once training scripts are rewritten in PEFT
-def text_encoder_lora_state_dict(text_encoder):
- state_dict = {}
-
- def text_encoder_attn_modules(text_encoder):
- from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
- attn_modules = []
-
- if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
- for i, layer in enumerate(text_encoder.text_model.encoder.layers):
- name = f"text_model.encoder.layers.{i}.self_attn"
- mod = layer.self_attn
- attn_modules.append((name, mod))
-
- return attn_modules
-
- for name, module in text_encoder_attn_modules(text_encoder):
- for k, v in module.q_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.k_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.v_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
- for k, v in module.out_proj.lora_linear_layer.state_dict().items():
- state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
- return state_dict
-
-
def save_model_card(
repo_id: str,
images=None,
@@ -180,12 +147,6 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--dataset_name",
type=str,
@@ -530,13 +491,12 @@ def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
- kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
- kwargs_handlers=[kwargs],
)
if args.report_to == "wandb":
@@ -576,16 +536,10 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer_2",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
)
# import correct text encoder classes
@@ -599,10 +553,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
vae_path = (
args.pretrained_model_name_or_path
@@ -610,13 +564,10 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
- vae_path,
- subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
- revision=args.revision,
- variant=args.variant,
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# We only train the additional adapter LoRA layers
@@ -813,7 +764,9 @@ def load_model_hook(models, input_dir):
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
- args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
)
else:
data_files = {}
@@ -1191,7 +1144,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1257,11 +1209,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
# Final inference
# Load previous pipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- vae=vae,
- revision=args.revision,
- variant=args.variant,
- torch_dtype=weight_dtype,
+ args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
)
pipeline = pipeline.to(accelerator.device)
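
The SDXL LoRA script writes `pytorch_lora_weights.safetensors` (which the new tests above check for); a minimal sketch of loading it back for inference (model id and path illustrative):

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# `load_lora_weights` reads pytorch_lora_weights.safetensors from the training output directory.
pipe.load_lora_weights("path/to/output_dir")
image = pipe("a photo in the fine-tuned style", num_inference_steps=25).images[0]
```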
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 0955d94b2202..c681943f2e94 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -57,7 +57,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
@@ -148,12 +148,6 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--dataset_name",
type=str,
@@ -624,16 +618,10 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
)
tokenizer_two = AutoTokenizer.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="tokenizer_2",
- revision=args.revision,
- use_fast=False,
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
)
# import correct text encoder classes
@@ -648,10 +636,10 @@ def main(args):
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# Check for terminal SNR in combination with SNR Gamma
text_encoder_one = text_encoder_cls_one.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
vae_path = (
args.pretrained_model_name_or_path
@@ -659,13 +647,10 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
- vae_path,
- subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
- revision=args.revision,
- variant=args.variant,
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# Freeze vae and text encoders.
@@ -692,7 +677,7 @@ def main(args):
# Create EMA for the unet.
if args.use_ema:
ema_unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
@@ -1053,6 +1038,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ prompt_embeds = prompt_embeds
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
).sample
@@ -1160,14 +1146,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
- variant=args.variant,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
if args.prediction_type is not None:
@@ -1215,16 +1199,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unet,
- vae=vae,
- revision=args.revision,
- variant=args.variant,
- torch_dtype=weight_dtype,
+ args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
)
if args.prediction_type is not None:
scheduler_args = {"prediction_type": args.prediction_type}
diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md
index 0a1d8a459fc6..21bca526b5d2 100644
--- a/examples/textual_inversion/README.md
+++ b/examples/textual_inversion/README.md
@@ -25,12 +25,12 @@ cd diffusers
pip install .
```
-Then cd in the example folder and run:
+Then cd in the example folder and run
```bash
pip install -r requirements.txt
```
-And initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
@@ -56,7 +56,7 @@ snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="d
```
This will be our training data.
-Now we can launch the training using:
+Now we can launch the training using
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
@@ -68,14 +68,12 @@ accelerate launch textual_inversion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--learnable_property="object" \
- --placeholder_token="<cat-toy>" \
- --initializer_token="toy" \
+ --placeholder_token="<cat-toy>" --initializer_token="toy" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=3000 \
- --learning_rate=5.0e-04 \
- --scale_lr \
+ --learning_rate=5.0e-04 --scale_lr \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--push_to_hub \
@@ -87,10 +85,10 @@ A full training run takes ~1 hour on one V100 GPU.
**Note**: As described in [the official paper](https://arxiv.org/abs/2208.01618)
only one embedding vector is used for the placeholder token, *e.g.* `"<cat-toy>"`.
However, one can also add multiple embedding vectors for the placeholder token
-to increase the number of fine-tuneable parameters. This can help the model to learn
-more complex details. To use multiple embedding vectors, you should define `--num_vectors`
+to increase the number of fine-tuneable parameters. This can help the model to learn
+more complex details. To use multiple embedding vectors, you should define `--num_vectors`
to a number larger than one, *e.g.*:
-```bash
+```
--num_vectors 5
```
@@ -133,13 +131,11 @@ python textual_inversion_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--learnable_property="object" \
- --placeholder_token="<cat-toy>" \
- --initializer_token="toy" \
+ --placeholder_token="<cat-toy>" --initializer_token="toy" \
--resolution=512 \
--train_batch_size=1 \
--max_train_steps=3000 \
- --learning_rate=5.0e-04 \
- --scale_lr \
+ --learning_rate=5.0e-04 --scale_lr \
--output_dir="textual_inversion_cat"
```
It should be at least 70% faster than the PyTorch script with the same configuration.
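
Once training finishes, the learned embedding can be attached to a pipeline; a minimal sketch using the output directory from the example command (base model id illustrative):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Loads the learned_embeds file saved by textual_inversion.py and registers the placeholder token.
pipe.load_textual_inversion("textual_inversion_cat")
image = pipe("a <cat-toy> backpack", num_inference_steps=25).images[0]
```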
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 8531a5a5d5b6..01830751ffe2 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -79,7 +79,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__)
@@ -126,7 +126,6 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
vae=vae,
safety_checker=None,
revision=args.revision,
- variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -207,12 +206,6 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
- parser.add_argument(
- "--variant",
- type=str,
- default=None,
- help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
- )
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -631,11 +624,9 @@ def main():
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
- vae = AutoencoderKL.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
- )
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
- args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# Add the placeholder token in tokenizer
@@ -761,7 +752,6 @@ def main():
num_cycles=args.lr_num_cycles,
)
- text_encoder.train()
# Prepare everything with our `accelerator`.
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 0af74bb5b25d..224c1147be9f 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -56,7 +56,7 @@
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 9af8203b3138..a3baa3b85b36 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -6,6 +6,7 @@
import shutil
from datetime import timedelta
from pathlib import Path
+from typing import Optional
import accelerate
import datasets
@@ -15,7 +16,7 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
-from huggingface_hub import create_repo, upload_folder
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
@@ -29,7 +30,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.24.0")
+check_min_version("0.22.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -272,6 +273,16 @@ def parse_args():
return args
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
def main(args):
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
@@ -345,13 +356,21 @@ def load_model_hook(models, input_dir):
# Handle the repository creation
if accelerator.is_main_process:
- if args.output_dir is not None:
- os.makedirs(args.output_dir, exist_ok=True)
-
if args.push_to_hub:
- repo_id = create_repo(
- repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
- ).repo_id
+ if args.hub_model_id is None:
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+ else:
+ repo_name = args.hub_model_id
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+ if "step_*" not in gitignore:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in gitignore:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
# Initialize the model
if args.model_config_name_or_path is None:
@@ -394,14 +413,6 @@ def load_model_hook(models, input_dir):
model_config=model.config,
)
- weight_dtype = torch.float32
- if accelerator.mixed_precision == "fp16":
- weight_dtype = torch.float16
- args.mixed_precision = accelerator.mixed_precision
- elif accelerator.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
- args.mixed_precision = accelerator.mixed_precision
-
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
@@ -548,9 +559,11 @@ def transform_images(examples):
progress_bar.update(1)
continue
- clean_images = batch["input"].to(weight_dtype)
+ clean_images = batch["input"]
# Sample noise that we'll add to the images
- noise = torch.randn(clean_images.shape, dtype=weight_dtype, device=clean_images.device)
+ noise = torch.randn(
+ clean_images.shape, dtype=(torch.float32 if args.mixed_precision == "no" else torch.float16)
+ ).to(clean_images.device)
bsz = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
@@ -566,14 +579,15 @@ def transform_images(examples):
model_output = model(noisy_images, timesteps).sample
if args.prediction_type == "epsilon":
- loss = F.mse_loss(model_output.float(), noise.float()) # this could have different weights!
+ loss = F.mse_loss(model_output, noise) # this could have different weights!
elif args.prediction_type == "sample":
alpha_t = _extract_into_tensor(
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
)
snr_weights = alpha_t / (1 - alpha_t)
- # use SNR weighting from distillation paper
- loss = snr_weights * F.mse_loss(model_output.float(), clean_images.float(), reduction="none")
+ loss = snr_weights * F.mse_loss(
+ model_output, clean_images, reduction="none"
+ ) # use SNR weighting from distillation paper
loss = loss.mean()
else:
raise ValueError(f"Unsupported prediction type: {args.prediction_type}")
@@ -689,12 +703,7 @@ def transform_images(examples):
ema_model.restore(unet.parameters())
if args.push_to_hub:
- upload_folder(
- repo_id=repo_id,
- folder_path=args.output_dir,
- commit_message=f"Epoch {epoch}",
- ignore_patterns=["step_*", "epoch_*"],
- )
+ repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
accelerator.end_training()
diff --git a/my_datasets/dataset_inpainting.py b/my_datasets/dataset_inpainting.py
new file mode 100644
index 000000000000..ed1135baa5f4
--- /dev/null
+++ b/my_datasets/dataset_inpainting.py
@@ -0,0 +1,451 @@
+import glob
+import logging
+import os
+import random
+
+import albumentations as A
+import cv2
+import io
+import numpy as np
+import hashlib
+import lmdb
+import pickle
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler, ConcatDataset
+import torchvision.transforms as transforms
+from torchvision.transforms import Compose, ToTensor, Normalize, RandomResizedCrop
+from PIL import Image
+from enum import Enum
+import csv
+import pandas as pd
+import json
+
+
+LOGGER = logging.getLogger(__name__)
+
+class DrawMethod(Enum):
+ LINE = 'line'
+ CIRCLE = 'circle'
+ SQUARE = 'square'
+
+class LinearRamp:
+ def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
+ self.start_value = start_value
+ self.end_value = end_value
+ self.start_iter = start_iter
+ self.end_iter = end_iter
+
+ def __call__(self, i):
+ if i < self.start_iter:
+ return self.start_value
+ if i >= self.end_iter:
+ return self.end_value
+ part = (i - self.start_iter) / (self.end_iter - self.start_iter)
+ return self.start_value * (1 - part) + self.end_value * part
+
+def loads_data(buf):
+ """
+ Args:
+ buf: the output of `dumps`.
+ """
+ return pickle.loads(buf)
+
+
+def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
+ draw_method=DrawMethod.LINE):
+ draw_method = DrawMethod(draw_method)
+
+ height, width = shape
+ mask = np.zeros((height, width), np.float32)
+ times = np.random.randint(min_times, max_times + 1)
+ for i in range(times):
+ start_x = np.random.randint(width)
+ start_y = np.random.randint(height)
+ for j in range(1 + np.random.randint(5)):
+ angle = 0.01 + np.random.randint(max_angle)
+ if i % 2 == 0:
+ angle = 2 * 3.1415926 - angle
+ length = 10 + np.random.randint(max_len)
+ brush_w = 5 + np.random.randint(max_width)
+ end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
+ end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
+ if draw_method == DrawMethod.LINE:
+ cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
+ elif draw_method == DrawMethod.CIRCLE:
+ cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
+ elif draw_method == DrawMethod.SQUARE:
+ radius = brush_w // 2
+ mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
+ start_x, start_y = end_x, end_y
+ return mask[..., None]
+
+
+class RandomIrregularMaskGenerator:
+ def __init__(self, max_angle=4, max_len=400, max_width=300, min_times=2, max_times=5, ramp_kwargs=None,
+ draw_method=DrawMethod.LINE):
+ self.max_angle = max_angle
+ self.max_len = max_len
+ self.max_width = max_width
+ self.min_times = min_times
+ self.max_times = max_times
+ self.draw_method = draw_method
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
+
+ def __call__(self, img, iter_i=None):
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
+ cur_max_len = int(max(1, self.max_len * coef))
+ cur_max_width = int(max(1, self.max_width * coef))
+ cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
+ return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
+ max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
+ draw_method=self.draw_method)
+
+def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
+ height, width = shape
+ mask = np.zeros((height, width), np.float32)
+ bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
+ times = np.random.randint(min_times, max_times + 1)
+ for i in range(times):
+ box_width = np.random.randint(bbox_min_size, bbox_max_size)
+ box_height = np.random.randint(bbox_min_size, bbox_max_size)
+ start_x = np.random.randint(margin, width - margin - box_width + 1)
+ start_y = np.random.randint(margin, height - margin - box_height + 1)
+ mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
+ return mask[..., None]
+
+
+class RandomRectangleMaskGenerator:
+ def __init__(self, margin=10, bbox_min_size=250, bbox_max_size=360, min_times=2, max_times=4, ramp_kwargs=None):
+ self.margin = margin
+ self.bbox_min_size = bbox_min_size
+ self.bbox_max_size = bbox_max_size
+ self.min_times = min_times
+ self.max_times = max_times
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
+
+ def __call__(self, img, iter_i=None):
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
+ cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
+ cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
+ return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
+ bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
+ max_times=cur_max_times)
+
+class HumanSegMaskGenerator:
+ def __init__(self, root):
+ self.mask_files = list(glob.glob(os.path.join(root, '**', '*.*'), recursive=True))
+ self.lens = len(self.mask_files)
+ self.mask_aug = transforms.RandomChoice(
+ [
+ transforms.RandomRotation((-30, 30), fill=(0,)),
+ transforms.RandomHorizontalFlip(),
+ ]
+ )
+
+ def __call__(self, image, iter_i):
+ ipt_h, ipt_w = image.shape[1:]
+ mask_index = random.randint(0, self.lens - 1)
+ maskname = self.mask_files[mask_index]
+ mask = Image.open(maskname).convert('L')
+ mask = np.array(self.mask_aug(mask))
+ h, w = mask.shape[:2]
+ ratio = min(ipt_h/h, ipt_w/w)
+ scale_w = int(w * ratio)
+ scale_h = int(h * ratio)
+ mask = cv2.resize(mask, (scale_w, scale_h), interpolation=cv2.INTER_NEAREST)
+ height_pad = ipt_h - scale_h
+ top = random.randint(0, height_pad)
+ bottom = height_pad - top
+ width_pad = ipt_w - scale_w
+ right = random.randint(0, width_pad)
+ left = width_pad - right
+ mask = np.pad(mask, ((top, bottom), (right, left)), mode='constant', constant_values=0)
+ mask = (mask > 0).astype(np.float32)
+ return mask[..., None]
+
+
+class OutpaintingMaskGenerator:
+ def __init__(self, min_padding_percent:float=0.2, max_padding_percent:float=0.5, left_padding_prob:float=0.6, top_padding_prob:float=0.6,
+ right_padding_prob:float=0.6, bottom_padding_prob:float=0.6, is_fixed_randomness:bool=False):
+ """
+ is_fixed_randomness - get identical paddings for the same image if args are the same
+ """
+ self.min_padding_percent = min_padding_percent
+ self.max_padding_percent = max_padding_percent
+ self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob]
+ self.is_fixed_randomness = is_fixed_randomness
+
+ assert self.min_padding_percent <= self.max_padding_percent
+ assert self.max_padding_percent > 0
+ assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]"
+ assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}"
+ assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}"
+ if len([x for x in self.probs if x > 0]) == 1:
+ LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side")
+
+ def apply_padding(self, mask, coord):
+ mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h),
+ int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1
+ return mask
+
+ def get_padding(self, size):
+ n1 = int(self.min_padding_percent*size)
+ n2 = int(self.max_padding_percent*size)
+ return self.rnd.randint(n1, n2) / size
+
+ @staticmethod
+ def _img2rs(img):
+ arr = np.ascontiguousarray(img.astype(np.uint8))
+ str_hash = hashlib.sha1(arr).hexdigest()
+ res = hash(str_hash)%(2**32)
+ return res
+
+ def __call__(self, img, iter_i=None, raw_image=None):
+ c, self.img_h, self.img_w = img.shape
+ mask = np.zeros((self.img_h, self.img_w), np.float32)
+ at_least_one_mask_applied = False
+
+ if self.is_fixed_randomness:
+ assert raw_image is not None, f"Cant calculate hash on raw_image=None"
+ rs = self._img2rs(raw_image)
+ self.rnd = np.random.RandomState(rs)
+ else:
+ self.rnd = np.random
+
+ coords = [[
+ (0,0),
+ (1,self.get_padding(size=self.img_h))
+ ],
+ [
+ (0,0),
+ (self.get_padding(size=self.img_w),1)
+ ],
+ [
+ (0,1-self.get_padding(size=self.img_h)),
+ (1,1)
+ ],
+ [
+ (1-self.get_padding(size=self.img_w),0),
+ (1,1)
+ ]]
+
+ for pp, coord in zip(self.probs, coords):
+ if self.rnd.random() < pp:
+ at_least_one_mask_applied = True
+ mask = self.apply_padding(mask=mask, coord=coord)
+
+ if not at_least_one_mask_applied:
+ idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs))
+ mask = self.apply_padding(mask=mask, coord=coords[idx])
+ return mask[..., None]
+
+
+class MixedMaskGenerator:
+ def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
+ box_proba=1/3, box_kwargs=None,
+ human_proba=0, human_mask_root=None,
+ outpainting_proba=0, outpainting_kwargs=None,
+ blank_mask_proba=0,
+ invert_proba=0):
+ self.probas = []
+ self.gens = []
+ self.blank_mask_proba = blank_mask_proba
+
+ if irregular_proba > 0:
+ self.probas.append(irregular_proba)
+ if irregular_kwargs is None:
+ irregular_kwargs = {}
+ else:
+ irregular_kwargs = dict(irregular_kwargs)
+ irregular_kwargs['draw_method'] = DrawMethod.LINE
+ self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
+
+ if box_proba > 0:
+ self.probas.append(box_proba)
+ if box_kwargs is None:
+ box_kwargs = {}
+ self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
+
+ if human_proba > 0:
+ assert os.path.exists(human_mask_root)
+ self.probas.append(human_proba)
+ self.gens.append(HumanSegMaskGenerator(human_mask_root))
+
+ if outpainting_proba > 0:
+ self.probas.append(outpainting_proba)
+ if outpainting_kwargs is None:
+ outpainting_kwargs = {}
+ self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs))
+
+ self.probas = np.array(self.probas, dtype='float32')
+ self.probas /= self.probas.sum()
+ self.invert_proba = invert_proba
+
+ def __call__(self, img, iter_i=None):
+ if np.random.random() < self.blank_mask_proba: # mask everything, for sd
+ result = np.ones(img.shape[1:])
+ return result[..., None]
+ kind = np.random.choice(len(self.probas), p=self.probas)
+ gen = self.gens[kind]
+ result = gen(img, iter_i=iter_i)
+ if self.invert_proba > 0 and random.random() < self.invert_proba:
+ result = 1 - result
+ return result
+
+class LoadImageFromLmdb(object):
+ def __init__(self, lmdb_path):
+ self.lmdb_path = lmdb_path
+ self.txn = None
+
+ def __call__(self, key):
+ if self.txn is None:
+ env = lmdb.open(self.lmdb_path, max_readers=4,
+ readonly=True, lock=False,
+ readahead=True, meminit=False)
+ self.txn = env.begin(write=False)
+ image_buf = self.txn.get(key.encode())
+ with Image.open(io.BytesIO(image_buf)) as image:
+ if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
+ image = image.convert("RGBA")
+ white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
+ white.paste(image, mask=image.split()[3])
+ image = white
+ else:
+ image = image.convert("RGB")
+ return image
+
+
+class InpaintingTextTrainDataset(Dataset):
+ def __init__(self, indir, args=None, mask_gen_kwargs=None):
+ self.blank_mask_prob=args.blank_mask_prob
+ if mask_gen_kwargs==None:
+ mask_gen_kwargs = {
+ "irregular_proba": 0.25,
+ "irregular_kwargs":{
+ "max_angle":4,
+ "max_len": 240,
+ "max_width": 100,
+ "max_times": 4 ,
+ "min_times": 1},
+ "box_proba": 0.25,
+ "box_kwargs": {
+ "margin": 10,
+ "bbox_min_size": 35,
+ "bbox_max_size": 160,
+ "max_times": 4,
+ "min_times": 1
+ },
+ "outpainting_proba": 0.5,
+ "outpainting_kwargs": {
+ "min_padding_percent": 0.25,
+ "max_padding_percent": 0.4,
+ "left_padding_prob": 0.5,
+ "top_padding_prob": 0.5,
+ "right_padding_prob": 0.5,
+ "bottom_padding_prob": 0.5
+ }
+ }
+
+ indir2 = os.path.join(indir,"LLMGA-dataset","coco2017_train.json")
+ image_folder2 = os.path.join(indir,"COCO","train2017")
+
+ self.txn1 = LoadImageFromLmdb(os.path.join(indir, "LAION-Aesthetic", "lmdb_train-00000-of-00002"))
+ self.txn2 = LoadImageFromLmdb(os.path.join(indir, "LAION-Aesthetic", "lmdb_train-00001-of-00002"))
+
+ with open(os.path.join(indir,"LLMGA-dataset","LAION","lmdb_train-00000-of-00002.json"), 'r', encoding='utf-8') as fr:
+ self.prompt_dict1 = json.load(fr)
+
+ with open(os.path.join(indir,"LLMGA-dataset","LAION","lmdb_train-00001-of-00002.json"), 'r', encoding='utf-8') as fr:
+ self.prompt_dict2 = json.load(fr)
+
+ with open(os.path.join(indir,"LLMGA-dataset","LAION","laion_3m_prompt.json"), 'r', encoding='utf-8') as fr:
+ self.prompt_dict_ori = json.load(fr)
+
+ self.prompt_dict3 = json.load(open(indir2,"r"))
+
+ self.image_folder2=image_folder2
+
+ self.train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution), #if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ self.mask_generator = MixedMaskGenerator(**mask_gen_kwargs)
+
+ self.len_1=len(self.prompt_dict1)
+ self.len_2=len(self.prompt_dict2)
+ self.len_3=len(self.prompt_dict3)
+
+
+ def __len__(self):
+ return self.len_1+self.len_2+ self.len_3
+
+ def __getitem__(self, index):
+     # NOTE: reconstructed loading logic -- the key/field lookups below are assumptions;
+     # only the final mask thresholding comes from the original line.
+     if index < self.len_1:
+         key = list(self.prompt_dict1.keys())[index]
+         image, prompt = self.txn1(key), self.prompt_dict1[key]
+     elif index < self.len_1 + self.len_2:
+         key = list(self.prompt_dict2.keys())[index - self.len_1]
+         image, prompt = self.txn2(key), self.prompt_dict2[key]
+     else:
+         item = self.prompt_dict3[index - self.len_1 - self.len_2]
+         image = Image.open(os.path.join(self.image_folder2, item["image"])).convert("RGB")
+         prompt = item["caption"]
+     img = self.train_transforms(image)
+     mask = torch.from_numpy(self.mask_generator(img, iter_i=index)).permute(2, 0, 1).float()
+     mask[mask>=0.5]=1
+ if random.random()<0.25:
+ mask=torch.ones_like(mask)
+ masked_img=img*(1-mask)
+ res = {
+ "pixel_values": img,
+ "caption": prompt,
+ "mask": mask,
+ "masked_image": masked_img,
+ }
+ return res
+
+
+
+
+if __name__ == "__main__":
+ irregular_kwargs = {
+ "max_angle": 4, "max_len": 400, "max_width": 300, "max_times": 5, "min_times": 2
+ }
+ box_kwargs = {
+ "margin": 10, "bbox_min_size": 250, "bbox_max_size": 360, "max_times": 4, "min_times": 3
+ }
+ outpainting_kwargs = {"min_padding_percent":0.2, "max_padding_percent":0.5, "left_padding_prob":0.6, "top_padding_prob":0.6, "right_padding_prob":0.6, "bottom_padding_prob":0.6}
+ gens = MixedMaskGenerator(irregular_proba=0.2, irregular_kwargs=irregular_kwargs,
+ box_proba=0.2, box_kwargs=box_kwargs,
+ human_proba=0, human_mask_root=None,
+ outpainting_proba=0.35, outpainting_kwargs=outpainting_kwargs,
+ blank_mask_proba=0.25)
+ img = np.zeros((512, 512, 3))
+ save_dir = "debug_mask"
+ os.makedirs(save_dir, exist_ok=True)
+ for i in range(4):
+ mask_list = []
+ for _ in range(8):
+ mask = gens(np.transpose(img, (2, 0, 1)))
+ mask_list.append(mask*255)
+ cv2.imwrite(f"{save_dir}/zmask_{i}.jpg", np.hstack(mask_list))
+
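
Each item from `InpaintingTextTrainDataset` carries `pixel_values`, `caption`, `mask`, and `masked_image`; a minimal sketch (not taken from the training scripts; the VAE scaling and channel order are assumptions based on the standard SD inpainting layout) of turning such a batch into the 9-channel UNet input:

```py
import torch
import torch.nn.functional as F

def build_inpainting_unet_input(batch, vae, noisy_latents):
    # Encode the masked image into latent space and apply the usual VAE scaling factor.
    masked_latents = vae.encode(batch["masked_image"]).latent_dist.sample() * vae.config.scaling_factor
    # Downsample the binary mask to the latent resolution.
    mask = F.interpolate(batch["mask"], size=masked_latents.shape[-2:])
    # SD inpainting UNets expect 4 (noisy latents) + 1 (mask) + 4 (masked-image latents) channels.
    return torch.cat([noisy_latents, mask, masked_latents], dim=1)
```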
diff --git a/my_datasets/dataset_inpainting_sdxl.py b/my_datasets/dataset_inpainting_sdxl.py
new file mode 100644
index 000000000000..a15a8fb42079
--- /dev/null
+++ b/my_datasets/dataset_inpainting_sdxl.py
@@ -0,0 +1,439 @@
+import glob
+import logging
+import os
+import random
+
+import cv2
+import io
+import numpy as np
+import hashlib
+import lmdb
+import pickle
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler, ConcatDataset
+import torchvision.transforms as transforms
+from torchvision.transforms import Compose, ToTensor, Normalize, RandomResizedCrop
+from PIL import Image
+from enum import Enum
+import csv
+import pandas as pd
+import json
+import albumentations as A
+
+LOGGER = logging.getLogger(__name__)
+
+class DrawMethod(Enum):
+ LINE = 'line'
+ CIRCLE = 'circle'
+ SQUARE = 'square'
+
+class LinearRamp:
+ def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
+ self.start_value = start_value
+ self.end_value = end_value
+ self.start_iter = start_iter
+ self.end_iter = end_iter
+
+ def __call__(self, i):
+ if i < self.start_iter:
+ return self.start_value
+ if i >= self.end_iter:
+ return self.end_value
+ part = (i - self.start_iter) / (self.end_iter - self.start_iter)
+ return self.start_value * (1 - part) + self.end_value * part
+
+def loads_data(buf):
+ """
+ Args:
+ buf: the output of `dumps`.
+ """
+ return pickle.loads(buf)
+
+
+def make_random_irregular_mask(shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10,
+ draw_method=DrawMethod.LINE):
+ draw_method = DrawMethod(draw_method)
+
+ height, width = shape
+ mask = np.zeros((height, width), np.float32)
+ times = np.random.randint(min_times, max_times + 1)
+ for i in range(times):
+ start_x = np.random.randint(width)
+ start_y = np.random.randint(height)
+ for j in range(1 + np.random.randint(5)):
+ angle = 0.01 + np.random.randint(max_angle)
+ if i % 2 == 0:
+ angle = 2 * 3.1415926 - angle
+ length = 10 + np.random.randint(max_len)
+ brush_w = 5 + np.random.randint(max_width)
+ end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width)
+ end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height)
+ if draw_method == DrawMethod.LINE:
+ cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w)
+ elif draw_method == DrawMethod.CIRCLE:
+ cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1., thickness=-1)
+ elif draw_method == DrawMethod.SQUARE:
+ radius = brush_w // 2
+ mask[start_y - radius:start_y + radius, start_x - radius:start_x + radius] = 1
+ start_x, start_y = end_x, end_y
+ return mask[..., None]
+
+
+class RandomIrregularMaskGenerator:
+ def __init__(self, max_angle=4, max_len=400, max_width=300, min_times=2, max_times=5, ramp_kwargs=None,
+ draw_method=DrawMethod.LINE):
+ self.max_angle = max_angle
+ self.max_len = max_len
+ self.max_width = max_width
+ self.min_times = min_times
+ self.max_times = max_times
+ self.draw_method = draw_method
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
+
+ def __call__(self, img, iter_i=None):
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
+ cur_max_len = int(max(1, self.max_len * coef))
+ cur_max_width = int(max(1, self.max_width * coef))
+ cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef)
+ return make_random_irregular_mask(img.shape[1:], max_angle=self.max_angle, max_len=cur_max_len,
+ max_width=cur_max_width, min_times=self.min_times, max_times=cur_max_times,
+ draw_method=self.draw_method)
+
+def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3):
+ height, width = shape
+ mask = np.zeros((height, width), np.float32)
+ bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2)
+ times = np.random.randint(min_times, max_times + 1)
+ for i in range(times):
+ box_width = np.random.randint(bbox_min_size, bbox_max_size)
+ box_height = np.random.randint(bbox_min_size, bbox_max_size)
+ start_x = np.random.randint(margin, width - margin - box_width + 1)
+ start_y = np.random.randint(margin, height - margin - box_height + 1)
+ mask[start_y:start_y + box_height, start_x:start_x + box_width] = 1
+ return mask[..., None]
+
+
+class RandomRectangleMaskGenerator:
+ def __init__(self, margin=10, bbox_min_size=250, bbox_max_size=360, min_times=2, max_times=4, ramp_kwargs=None):
+ self.margin = margin
+ self.bbox_min_size = bbox_min_size
+ self.bbox_max_size = bbox_max_size
+ self.min_times = min_times
+ self.max_times = max_times
+ self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None
+
+ def __call__(self, img, iter_i=None):
+ coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1
+ cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef)
+ cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef)
+ return make_random_rectangle_mask(img.shape[1:], margin=self.margin, bbox_min_size=self.bbox_min_size,
+ bbox_max_size=cur_bbox_max_size, min_times=self.min_times,
+ max_times=cur_max_times)
+
+class HumanSegMaskGenerator:
+ def __init__(self, root):
+ self.mask_files = list(glob.glob(os.path.join(root, '**', '*.*'), recursive=True))
+ self.lens = len(self.mask_files)
+ self.mask_aug = transforms.RandomChoice(
+ [
+ transforms.RandomRotation((-30, 30), fill=(0,)),
+ transforms.RandomHorizontalFlip(),
+ ]
+ )
+
+ def __call__(self, image, iter_i):
+ ipt_h, ipt_w = image.shape[1:]
+ mask_index = random.randint(0, self.lens - 1)
+ maskname = self.mask_files[mask_index]
+ mask = Image.open(maskname).convert('L')
+ mask = np.array(self.mask_aug(mask))
+ h, w = mask.shape[:2]
+ ratio = min(ipt_h/h, ipt_w/w)
+ scale_w = int(w * ratio)
+ scale_h = int(h * ratio)
+ mask = cv2.resize(mask, (scale_w, scale_h), interpolation=cv2.INTER_NEAREST)
+ height_pad = ipt_h - scale_h
+ top = random.randint(0, height_pad)
+ bottom = height_pad - top
+ width_pad = ipt_w - scale_w
+ right = random.randint(0, width_pad)
+ left = width_pad - right
+ mask = np.pad(mask, ((top, bottom), (right, left)), mode='constant', constant_values=0)
+ mask = (mask > 0).astype(np.float32)
+ return mask[..., None]
+
+
+# {"min_padding_percent":0.2, "max_padding_percent":0.5, "left_padding_prob":0.6, "top_padding_prob":0.6, "right_padding_prob":0.6, "bottom_padding_prob":0.6}
+class OutpaintingMaskGenerator:
+    def __init__(self, min_padding_percent:float=0.2, max_padding_percent:float=0.5, left_padding_prob:float=0.6, top_padding_prob:float=0.6,
+ right_padding_prob:float=0.6, bottom_padding_prob:float=0.6, is_fixed_randomness:bool=False):
+ """
+ is_fixed_randomness - get identical paddings for the same image if args are the same
+ """
+ self.min_padding_percent = min_padding_percent
+ self.max_padding_percent = max_padding_percent
+ self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob]
+ self.is_fixed_randomness = is_fixed_randomness
+
+ assert self.min_padding_percent <= self.max_padding_percent
+ assert self.max_padding_percent > 0
+ assert len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x>=0 and x<=1)]) == 2, f"Padding percentage should be in [0,1]"
+ assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}"
+ assert len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4, f"At least one of padding probs is not in [0,1] - {self.probs}"
+ if len([x for x in self.probs if x > 0]) == 1:
+ LOGGER.warning(f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side")
+
+ def apply_padding(self, mask, coord):
+ mask[int(coord[0][0]*self.img_h):int(coord[1][0]*self.img_h),
+ int(coord[0][1]*self.img_w):int(coord[1][1]*self.img_w)] = 1
+ return mask
+
+ def get_padding(self, size):
+ n1 = int(self.min_padding_percent*size)
+ n2 = int(self.max_padding_percent*size)
+ return self.rnd.randint(n1, n2) / size
+
+ @staticmethod
+ def _img2rs(img):
+ arr = np.ascontiguousarray(img.astype(np.uint8))
+ str_hash = hashlib.sha1(arr).hexdigest()
+ res = hash(str_hash)%(2**32)
+ return res
+
+ def __call__(self, img, iter_i=None, raw_image=None):
+ c, self.img_h, self.img_w = img.shape
+ mask = np.zeros((self.img_h, self.img_w), np.float32)
+ at_least_one_mask_applied = False
+
+ if self.is_fixed_randomness:
+ assert raw_image is not None, f"Cant calculate hash on raw_image=None"
+ rs = self._img2rs(raw_image)
+ self.rnd = np.random.RandomState(rs)
+ else:
+ self.rnd = np.random
+
+ coords = [[
+ (0,0),
+ (1,self.get_padding(size=self.img_h))
+ ],
+ [
+ (0,0),
+ (self.get_padding(size=self.img_w),1)
+ ],
+ [
+ (0,1-self.get_padding(size=self.img_h)),
+ (1,1)
+ ],
+ [
+ (1-self.get_padding(size=self.img_w),0),
+ (1,1)
+ ]]
+
+ for pp, coord in zip(self.probs, coords):
+ if self.rnd.random() < pp:
+ at_least_one_mask_applied = True
+ mask = self.apply_padding(mask=mask, coord=coord)
+
+ if not at_least_one_mask_applied:
+ idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs)/sum(self.probs))
+ mask = self.apply_padding(mask=mask, coord=coords[idx])
+ return mask[..., None]
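
A quick way to see what the class above produces (a throwaway check, assuming `OutpaintingMaskGenerator` is in scope): each sampled padding turns one image border into a strip of ones, so the borders of the returned mask reveal which sides were selected.

```py
import numpy as np

gen = OutpaintingMaskGenerator(min_padding_percent=0.2, max_padding_percent=0.5)
dummy = np.zeros((3, 512, 512), dtype=np.float32)    # C x H x W; the pixel values are irrelevant
mask = gen(dummy)[..., 0]                            # H x W, 1.0 marks the region to outpaint

print("left column fully masked: ", bool(mask[:, 0].all()))
print("top row fully masked:     ", bool(mask[0, :].all()))
print("right column fully masked:", bool(mask[:, -1].all()))
print("bottom row fully masked:  ", bool(mask[-1, :].all()))
```
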
+
+
+class MixedMaskGenerator:
+ def __init__(self, irregular_proba=1/3, irregular_kwargs=None,
+ box_proba=1/3, box_kwargs=None,
+ human_proba=0, human_mask_root=None,
+ outpainting_proba=0, outpainting_kwargs=None,
+ blank_mask_proba=0,
+ invert_proba=0):
+ self.probas = []
+ self.gens = []
+ self.blank_mask_proba = blank_mask_proba
+
+ if irregular_proba > 0:
+ self.probas.append(irregular_proba)
+ if irregular_kwargs is None:
+ irregular_kwargs = {}
+ else:
+ irregular_kwargs = dict(irregular_kwargs)
+ irregular_kwargs['draw_method'] = DrawMethod.LINE
+ self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs))
+
+ if box_proba > 0:
+ self.probas.append(box_proba)
+ if box_kwargs is None:
+ box_kwargs = {}
+ self.gens.append(RandomRectangleMaskGenerator(**box_kwargs))
+
+ if human_proba > 0:
+ assert os.path.exists(human_mask_root)
+ self.probas.append(human_proba)
+ self.gens.append(HumanSegMaskGenerator(human_mask_root))
+
+ if outpainting_proba > 0:
+ self.probas.append(outpainting_proba)
+ if outpainting_kwargs is None:
+ outpainting_kwargs = {}
+ self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs))
+
+ self.probas = np.array(self.probas, dtype='float32')
+ self.probas /= self.probas.sum()
+ self.invert_proba = invert_proba
+
+ def __call__(self, img, iter_i=None):
+ if np.random.random() < self.blank_mask_proba: # mask everything, for sd
+ result = np.ones(img.shape[1:])
+ return result[..., None]
+ kind = np.random.choice(len(self.probas), p=self.probas)
+ gen = self.gens[kind]
+ result = gen(img, iter_i=iter_i)
+ if self.invert_proba > 0 and random.random() < self.invert_proba:
+ result = 1 - result
+ return result
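
One detail worth keeping in mind when configuring `MixedMaskGenerator`: `blank_mask_proba` is handled separately, and the per-generator probabilities are renormalized among themselves. A tiny check, using the same numbers as the debug `__main__` above:

```py
gen = MixedMaskGenerator(irregular_proba=0.2, box_proba=0.2,
                         outpainting_proba=0.35, blank_mask_proba=0.25)
# 25% of calls return an all-ones mask; the rest split 0.2 : 0.2 : 0.35 after renormalization,
# i.e. roughly [0.267, 0.267, 0.467] over irregular, box and outpainting masks.
print(gen.probas)
```
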
+
+
+class LoadImageFromLmdb(object):
+ def __init__(self, lmdb_path):
+ self.lmdb_path = lmdb_path
+ self.txn = None
+
+ def __call__(self, key):
+ if self.txn is None:
+ env = lmdb.open(self.lmdb_path, max_readers=4,
+ readonly=True, lock=False,
+ readahead=True, meminit=False)
+ self.txn = env.begin(write=False)
+ image_buf = self.txn.get(key.encode())
+ with Image.open(io.BytesIO(image_buf)) as image:
+ if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
+ image = image.convert("RGBA")
+ white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
+ white.paste(image, mask=image.split()[3])
+ image = white
+ else:
+ image = image.convert("RGB")
+ return image
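
A short usage sketch for the reader above (the LMDB path and key are placeholders): the environment is opened lazily on first access, which keeps the object picklable for `DataLoader` workers, and images with transparency are composited onto a white background before being returned as RGB.

```py
reader = LoadImageFromLmdb("/path/to/LAION-Aesthetic/lmdb_train-00000-of-00002")
img = reader("some_image_key")   # PIL.Image in RGB mode
print(img.mode, img.size)
```
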
+
+
+
+
+
+class InpaintingTextTrainDataset(Dataset):
+ def __init__(self, indir, args=None,mask_gen_kwargs=None):
+ self.txn1 = LoadImageFromLmdb(os.path.join(indir, "LAION-Aesthetic", "lmdb_train-00000-of-00002"))
+ self.txn2 = LoadImageFromLmdb(os.path.join(indir, "LAION-Aesthetic", "lmdb_train-00001-of-00002"))
+
+ with open(os.path.join(indir,"LLMGA-dataset","LAION","lmdb_train-00000-of-00002.json"), 'r', encoding='utf-8') as fr:
+ self.prompt_dict1 = json.load(fr)
+
+ with open(os.path.join(indir,"LLMGA-dataset","LAION","lmdb_train-00001-of-00002.json"), 'r', encoding='utf-8') as fr:
+ self.prompt_dict2 = json.load(fr)
+
+ with open(os.path.join(indir,"LLMGA-dataset","LAION","laion_3m_prompt.json"), 'r', encoding='utf-8') as fr:
+ self.prompt_dict_ori = json.load(fr)
+
+ self.args = args
+
+ self.train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ self.train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ self.train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ self.train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+        if mask_gen_kwargs is None:
+ mask_gen_kwargs = {
+ "irregular_proba": 0.25,
+ "irregular_kwargs":{
+ "max_angle":4,
+ "max_len": 240,
+ "max_width": 100,
+ "max_times": 4 ,
+ "min_times": 1},
+ "box_proba": 0.25,
+ "box_kwargs": {
+ "margin": 10,
+ "bbox_min_size": 35,
+ "bbox_max_size": 160,
+ "max_times": 4,
+ "min_times": 1
+ },
+ "outpainting_proba": 0.5,
+ "outpainting_kwargs": {
+ "min_padding_percent": 0.25,
+ "max_padding_percent": 0.4,
+ "left_padding_prob": 0.5,
+ "top_padding_prob": 0.5,
+ "right_padding_prob": 0.5,
+ "bottom_padding_prob": 0.5
+ }
+ }
+ self.mask_generator = MixedMaskGenerator(**mask_gen_kwargs)
+ self.len_1=len(self.prompt_dict1)
+ self.len_2=len(self.prompt_dict2)
+
+ def preprocess_train(self, examples):
+ # image aug
+ image = examples["pixel_values"]
+ original_sizes=(image.height, image.width)
+ image = self.train_resize(image)
+        # crop offsets are computed as for a center crop; note that train_crop is a RandomCrop when
+        # args.center_crop is False, so the recorded offsets are an approximation in that case
+        y1 = max(0, int(round((image.height - self.args.resolution) / 2.0)))
+        x1 = max(0, int(round((image.width - self.args.resolution) / 2.0)))
+        image = self.train_crop(image)
+ if self.args.random_flip and random.random() < 0.5:
+ # flip
+ x1 = image.width - x1
+ image = self.train_flip(image)
+ crop_top_left = (y1, x1)
+ crop_top_lefts=crop_top_left
+ image = self.train_transforms(image)
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = image
+ return examples
+
+ def __len__(self):
+ return self.len_1+self.len_2
+
+ def __getitem__(self, index):
+ if index < self.len_1:
+ txn = self.txn1
+ keys = self.prompt_dict1
+ else:
+ txn = self.txn2
+ keys = self.prompt_dict2
+ index = index - self.len_1
+ key = keys[index]["image"]
+ img = txn(key)
+ if random.random()<0.05:
+ prompt=self.prompt_dict_ori[key]
+ else:
+ prompt = keys[index]["conversations"][1]["value"]
+
+ if random.random()<0.05:
+ prompt=""
+ examples = {
+ "pixel_values": img,
+ "caption": prompt,
+ }
+ examples = self.preprocess_train(examples)
+ mask = self.mask_generator(examples["pixel_values"], iter_i=index)
+ mask = torch.from_numpy(mask).permute(2, 0, 1)
+ mask[mask<0.5]=0
+ mask[mask>=0.5]=1
+ if random.random()<0.25:
+ mask=torch.ones_like(mask)
+ masked_img=examples["pixel_values"]*(1-mask)
+ examples.update({
+ "mask": mask,
+ "masked_image": masked_img,
+ })
+
+ return examples
+
+
+
+
+
+
+
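
The `original_sizes` and `crop_top_lefts` fields returned by `preprocess_train` are SDXL's micro-conditioning signals. They are not consumed in this file; in a training loop they would typically be packed, together with the target resolution, into the `add_time_ids` tensor, as in the diffusers SDXL training scripts. A minimal sketch of that packing (the helper name and the example values are illustrative, not part of this diff):

```py
import torch

def compute_add_time_ids(original_size, crop_top_left, target_size, dtype=torch.float32):
    # SDXL conditions the UNet on (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
    return torch.tensor([list(original_size) + list(crop_top_left) + list(target_size)], dtype=dtype)

example = {"original_sizes": (768, 1024), "crop_top_lefts": (128, 256)}
add_time_ids = compute_add_time_ids(example["original_sizes"], example["crop_top_lefts"], (1024, 1024))
print(add_time_ids)   # tensor([[ 768., 1024.,  128.,  256., 1024., 1024.]])
```
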
diff --git a/my_datasets/dataset_text2img.py b/my_datasets/dataset_text2img.py
new file mode 100644
index 000000000000..e79c6e22ec97
--- /dev/null
+++ b/my_datasets/dataset_text2img.py
@@ -0,0 +1,125 @@
+import glob
+import logging
+import os
+import random
+
+import cv2
+import io
+import numpy as np
+import hashlib
+import lmdb
+import pickle
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset, IterableDataset, DataLoader, DistributedSampler, ConcatDataset
+import torchvision.transforms as transforms
+from torchvision.transforms import Compose, ToTensor, Normalize, RandomResizedCrop
+from PIL import Image
+from enum import Enum
+import csv
+import pandas as pd
+import json
+
+
+
+
+class LoadImageFromLmdb(object):
+ def __init__(self, lmdb_path):
+ self.lmdb_path = lmdb_path
+ self.txn = None
+
+ def __call__(self, key):
+ if self.txn is None:
+ env = lmdb.open(self.lmdb_path, max_readers=4,
+ readonly=True, lock=False,
+ readahead=True, meminit=False)
+ self.txn = env.begin(write=False)
+ image_buf = self.txn.get(key.encode())
+ with Image.open(io.BytesIO(image_buf)) as image:
+ if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
+ image = image.convert("RGBA")
+ white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
+ white.paste(image, mask=image.split()[3])
+ image = white
+ else:
+ image = image.convert("RGB")
+ return image
+
+
+
+
+
+class Text2ImgTrainDataset(Dataset):
+ def __init__(self, indir, args=None):
+ indir_coco = os.path.join(indir,"LLMGA-dataset","coco2017_train.json")
+ image_folder2 = os.path.join(indir,"COCO","train2017")
+ self.txn1 = LoadImageFromLmdb(os.path.join(indir, "LAION-Aesthetic", "lmdb_train-00000-of-00002"))
+ self.txn2 = LoadImageFromLmdb(os.path.join(indir, "LAION-Aesthetic", "lmdb_train-00001-of-00002"))
+
+ with open(os.path.join(indir,"LLMGA-dataset","LAION","lmdb_train-00000-of-00002.json"), 'r', encoding='utf-8') as fr:
+ self.prompt_dict1 = json.load(fr)
+
+ with open(os.path.join(indir,"LLMGA-dataset","LAION","lmdb_train-00001-of-00002.json"), 'r', encoding='utf-8') as fr:
+ self.prompt_dict2 = json.load(fr)
+
+ with open(os.path.join(indir,"LLMGA-dataset","LAION","laion_3m_prompt.json"), 'r', encoding='utf-8') as fr:
+ self.prompt_dict_ori = json.load(fr)
+
+ self.prompt_coco_dict = json.load(open(indir_coco,"r"))
+
+ self.image_folder2 = image_folder2
+
+ self.train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution), #if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ self.len_1=len(self.prompt_dict1)
+ self.len_2=len(self.prompt_dict2)
+ self.len_3=len(self.prompt_coco_dict)
+
+ def __len__(self):
+ return self.len_1+self.len_2+ self.len_3
+
+ def __getitem__(self, index):
+        if index < self.len_1:
+            txn = self.txn1
+            keys = self.prompt_dict1
+        elif index < self.len_1 + self.len_2:
+            txn = self.txn2
+            keys = self.prompt_dict2
+            index = index - self.len_1
+        else:
+            txn = None
+            keys = self.prompt_coco_dict
+            index = index - self.len_1 - self.len_2
+        key = keys[index]["image"]
+        if txn is not None:
+            img = txn(key)
+            if random.random()<0.05:
+                prompt=self.prompt_dict_ori[key]
+            else:
+                prompt = keys[index]["conversations"][1]["value"]
+        else:
+            img = Image.open(os.path.join(self.image_folder2, key)).convert("RGB")
+            prompt = keys[index]["conversations"][1]["value"]
+        if random.random()<0.05:
+            prompt=""
+        img = self.train_transforms(img)
+        res = {
+            "pixel_values": img,
+            "caption": prompt,
+        }
+        return res
+
+
+def parse_prompt_attention(text):
+    """
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+    >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\(literal\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith("\\"):
+ res.append([text[1:], 1.0])
+ elif text == "(":
+ round_brackets.append(len(res))
+ elif text == "[":
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ")" and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == "]" and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
+ return res
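
The `1.573...` weight for `house` in the doctest above is simply the nested multipliers stacking: two bare parentheses contribute a factor of 1.1 each on top of the explicit `:1.3`. A quick check, assuming the function above is importable:

```py
print(1.1 * 1.1 * 1.3)   # 1.5730000000000004, matching the doctest
print(parse_prompt_attention("a (((house:1.3)) [on] a (hill:0.5), sun, (((sky)))."))
```
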
+
+
+
+def pad_tokens(tokens, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
+ r"""
+    Pad the tokens (with starting, padding and ending tokens) to max_length.
+ """
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+ for i in range(len(tokens)):
+ tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
+
+ return tokens
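
A toy illustration of `pad_tokens` with made-up ids and a small `chunk_length` (7 instead of CLIP's 77): every prompt is wrapped with the `bos`/`eos` ids and right-padded with `pad` so that all rows share the requested `max_length`.

```py
toy = [[5, 6, 7], [8]]
print(pad_tokens(toy, max_length=12, bos=1, eos=2, pad=0, chunk_length=7))
# [[1, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 2], [1, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]]
```
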
+
+def get_unweighted_text_embeddings_SDXL2(
+ text_encoder,
+ text_input: torch.Tensor,
+ chunk_length: int,
+ no_boseos_middle: Optional[bool] = True,
+ clip_skip: Optional[int] = None,
+):
+ """
+ When the length of tokens is a multiple of the capacity of the text encoder,
+ it should be split into chunks and sent to the text encoder individually.
+ """
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+ if max_embeddings_multiples > 1:
+ text_embeddings = None
+ max_ids = None
+
+ hidden_states_all = []
+ for i in range(max_embeddings_multiples):
+ # extract the i-th chunk
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+ # cover the head and the tail by the starting and the ending tokens
+ text_input_chunk[:, 0] = text_input[0, 0]
+ text_input_chunk[:, -1] = text_input[0, -1]
+ prompt_embeds = text_encoder(text_input_chunk,output_hidden_states=True)
+
+ text_embedding = prompt_embeds[0]
+ # hidden_states = prompt_embeds.hidden_states
+ if clip_skip is None:
+ hidden_states = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ hidden_states = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ if no_boseos_middle:
+ if i == 0:
+ # discard the ending token
+ hidden_states = hidden_states[:, :-1]
+ elif i == max_embeddings_multiples - 1:
+ # discard the starting token
+ hidden_states = hidden_states[:, 1:]
+ else:
+ # discard both starting and ending tokens
+ hidden_states = hidden_states[:, 1:-1]
+
+ if text_embeddings is None:
+ text_embeddings=text_embedding
+ max_ids,_ = torch.max(text_input_chunk,dim=1,keepdim=True) #[B]
+ max_ids = max_ids.view(-1,1)
+ else:
+ now_max_ids,_ = torch.max(text_input_chunk,dim=1,keepdim=True)
+ now_max_ids = now_max_ids.view(-1,1)
+ text_embeddings = torch.where( max_ids>now_max_ids, text_embeddings,text_embedding)
+
+ hidden_states_all.append(hidden_states)
+ hidden_states = torch.concat(hidden_states_all, axis=1)
+ else:
+ prompt_embeds = text_encoder(text_input,output_hidden_states=True)
+ text_embeddings = prompt_embeds[0]
+ if clip_skip is None:
+ hidden_states = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ hidden_states = prompt_embeds.hidden_states[-(clip_skip + 2)]
+ return text_embeddings, hidden_states
+
+def get_unweighted_text_embeddings_SDXL1(
+ text_encoder,
+ text_input: torch.Tensor,
+ chunk_length: int,
+ no_boseos_middle: Optional[bool] = True,
+ clip_skip: Optional[int] = None,
+):
+ """
+ When the length of tokens is a multiple of the capacity of the text encoder,
+ it should be split into chunks and sent to the text encoder individually.
+ """
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+ if max_embeddings_multiples > 1:
+ text_embeddings = []
+ hidden_states_all = []
+ for i in range(max_embeddings_multiples):
+ # extract the i-th chunk
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+ # cover the head and the tail by the starting and the ending tokens
+ text_input_chunk[:, 0] = text_input[0, 0]
+ text_input_chunk[:, -1] = text_input[0, -1]
+ prompt_embeds = text_encoder(text_input_chunk,output_hidden_states=True)
+
+ text_embedding = prompt_embeds[0]
+ if clip_skip is None:
+ hidden_states = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ hidden_states = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ if no_boseos_middle:
+ if i == 0:
+ # discard the ending token
+ text_embedding = text_embedding[:, :-1]
+ hidden_states = hidden_states[:, :-1]
+ elif i == max_embeddings_multiples - 1:
+ # discard the starting token
+ text_embedding = text_embedding[:, 1:]
+ hidden_states = hidden_states[:, 1:]
+ else:
+ # discard both starting and ending tokens
+ text_embedding = text_embedding[:, 1:-1]
+ hidden_states = hidden_states[:, 1:-1]
+
+ text_embeddings.append(text_embedding)
+ hidden_states_all.append(hidden_states)
+ text_embeddings = torch.concat(text_embeddings, axis=1)
+ hidden_states = torch.concat(hidden_states_all, axis=1)
+ else:
+ prompt_embeds = text_encoder(text_input,output_hidden_states=True)
+ text_embeddings = prompt_embeds[0]
+ if clip_skip is None:
+ hidden_states = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ hidden_states = prompt_embeds.hidden_states[-(clip_skip + 2)]
+ return text_embeddings, hidden_states
+
+
+
+def get_unweighted_text_embeddings(
+ text_encoder,
+ text_input: torch.Tensor,
+ chunk_length: int,
+ no_boseos_middle: Optional[bool] = True,
+):
+ """
+ When the length of tokens is a multiple of the capacity of the text encoder,
+ it should be split into chunks and sent to the text encoder individually.
+ """
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+ if max_embeddings_multiples > 1:
+ text_embeddings = []
+ for i in range(max_embeddings_multiples):
+ # extract the i-th chunk
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+ # cover the head and the tail by the starting and the ending tokens
+ text_input_chunk[:, 0] = text_input[0, 0]
+ text_input_chunk[:, -1] = text_input[0, -1]
+ text_embedding = text_encoder(text_input_chunk)[0]
+
+ if no_boseos_middle:
+ if i == 0:
+ # discard the ending token
+ text_embedding = text_embedding[:, :-1]
+ elif i == max_embeddings_multiples - 1:
+ # discard the starting token
+ text_embedding = text_embedding[:, 1:]
+ else:
+ # discard both starting and ending tokens
+ text_embedding = text_embedding[:, 1:-1]
+
+ text_embeddings.append(text_embedding)
+ text_embeddings = torch.concat(text_embeddings, axis=1)
+ else:
+ text_embeddings = text_encoder(text_input)[0]
+ return text_embeddings
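
The chunking above can be sanity-checked without a real CLIP model by plugging in a stand-in encoder that just returns random features: a 227-token input (`(77 - 2) * 3 + 2`) is split into three 77-token chunks whose trimmed outputs (76 + 75 + 76 positions) reassemble to the original length. A toy sketch, not part of the training code:

```py
import torch

def fake_text_encoder(ids):
    # mimic the (last_hidden_state,) tuple returned by a CLIP text encoder
    return (torch.randn(ids.shape[0], ids.shape[1], 8),)

text_input = torch.randint(0, 1000, (1, 227))
emb = get_unweighted_text_embeddings(fake_text_encoder, text_input, chunk_length=77)
print(emb.shape)   # torch.Size([1, 227, 8])
```
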
+
+
+def get_text_index(
+ tokenizer,
+ prompt: Union[str, List[str]],
+ max_embeddings_multiples: Optional[int] = 4,
+ no_boseos_middle: Optional[bool] = False,
+):
+ r"""
+ Prompts can be assigned with local weights using brackets. For example,
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
+
+ Args:
+ pipe (`DiffusionPipeline`):
+ Pipe to provide access to the tokenizer and the text encoder.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
+ ending token in each of the chunk in the middle.
+ """
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ prompt_tokens = [
+ token[1:-1] for token in tokenizer(prompt, max_length=max_length, truncation=True).input_ids
+ ]
+
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
+ max_length = max([len(token) for token in prompt_tokens])
+
+ max_embeddings_multiples = min(
+ max_embeddings_multiples,
+ (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
+ )
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+ # pad the length of tokens and weights
+ bos = tokenizer.bos_token_id
+ eos = tokenizer.eos_token_id
+ pad = getattr(tokenizer, "pad_token_id", eos)
+ prompt_tokens = pad_tokens(
+ prompt_tokens,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=tokenizer.model_max_length,
+ )
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long)
+
+ return prompt_tokens
+
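
A hedged usage sketch for `get_text_index` with a standard CLIP tokenizer (the checkpoint name and prompt are just examples): a long prompt is padded to `(77 - 2) * k + 2` token ids, where `k` is the number of 75-token chunks actually needed, ready to be fed to the chunked encoders above.

```py
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
long_prompt = "a photo of a castle, " * 30   # well beyond the 77-token limit
ids = get_text_index(tokenizer, long_prompt)
print(ids.shape)   # e.g. torch.Size([1, 227]) when three chunks are needed
```
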
diff --git a/pip.sh b/pip.sh
new file mode 100644
index 000000000000..39fcc39d576a
--- /dev/null
+++ b/pip.sh
@@ -0,0 +1,3 @@
+pip install datasets --index-url https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
+pip install . --index-url https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
+pip install albumentations --index-url https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
\ No newline at end of file
diff --git a/pipeline_stable_diffusion_inpaint_lpw.py b/pipeline_stable_diffusion_inpaint_lpw.py
new file mode 100644
index 000000000000..bd31f597fe71
--- /dev/null
+++ b/pipeline_stable_diffusion_inpaint_lpw.py
@@ -0,0 +1,1306 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+import re
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+re_attention = re.compile(
+ r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""",
+ re.X,
+)
+
+
+def parse_prompt_attention(text):
+ """
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\(literal\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith("\\"):
+ res.append([text[1:], 1.0])
+ elif text == "(":
+ round_brackets.append(len(res))
+ elif text == "[":
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ")" and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == "]" and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
+ return res
+
+
+def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
+ r"""
+ Tokenize a list of prompts and return its tokens with weights of each token.
+
+ No padding, starting or ending token is included.
+ """
+ tokens = []
+ weights = []
+ truncated = False
+ for text in prompt:
+ texts_and_weights = parse_prompt_attention(text)
+ text_token = []
+ text_weight = []
+ for word, weight in texts_and_weights:
+ # tokenize and discard the starting and the ending token
+ token = pipe.tokenizer(word).input_ids[1:-1]
+ text_token += token
+ # copy the weight by length of token
+ text_weight += [weight] * len(token)
+ # stop if the text is too long (longer than truncation limit)
+ if len(text_token) > max_length:
+ truncated = True
+ break
+ # truncate
+ if len(text_token) > max_length:
+ truncated = True
+ text_token = text_token[:max_length]
+ text_weight = text_weight[:max_length]
+ tokens.append(text_token)
+ weights.append(text_weight)
+ if truncated:
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
+ return tokens, weights
+
+
+def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
+ r"""
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+ """
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+ for i in range(len(tokens)):
+ tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
+ if no_boseos_middle:
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+ else:
+ w = []
+ if len(weights[i]) == 0:
+ w = [1.0] * weights_length
+ else:
+ for j in range(max_embeddings_multiples):
+ w.append(1.0) # weight for starting token in this chunk
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
+ w.append(1.0) # weight for ending token in this chunk
+ w += [1.0] * (weights_length - len(w))
+ weights[i] = w[:]
+
+ return tokens, weights
+
+
+def get_unweighted_text_embeddings(
+ pipe: DiffusionPipeline,
+ text_input: torch.Tensor,
+ chunk_length: int,
+ no_boseos_middle: Optional[bool] = True,
+):
+ """
+ When the length of tokens is a multiple of the capacity of the text encoder,
+ it should be split into chunks and sent to the text encoder individually.
+ """
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+ if max_embeddings_multiples > 1:
+ text_embeddings = []
+ for i in range(max_embeddings_multiples):
+ # extract the i-th chunk
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+ # cover the head and the tail by the starting and the ending tokens
+ text_input_chunk[:, 0] = text_input[0, 0]
+ text_input_chunk[:, -1] = text_input[0, -1]
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
+
+ if no_boseos_middle:
+ if i == 0:
+ # discard the ending token
+ text_embedding = text_embedding[:, :-1]
+ elif i == max_embeddings_multiples - 1:
+ # discard the starting token
+ text_embedding = text_embedding[:, 1:]
+ else:
+ # discard both starting and ending tokens
+ text_embedding = text_embedding[:, 1:-1]
+
+ text_embeddings.append(text_embedding)
+ text_embeddings = torch.concat(text_embeddings, axis=1)
+ else:
+ text_embeddings = pipe.text_encoder(text_input)[0]
+ return text_embeddings
+
+
+def get_weighted_text_embeddings(
+ pipe: DiffusionPipeline,
+ prompt: Union[str, List[str]],
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
+ max_embeddings_multiples: Optional[int] = 3,
+ no_boseos_middle: Optional[bool] = False,
+ skip_parsing: Optional[bool] = False,
+ skip_weighting: Optional[bool] = False,
+):
+ r"""
+ Prompts can be assigned with local weights using brackets. For example,
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
+
+ Args:
+ pipe (`DiffusionPipeline`):
+ Pipe to provide access to the tokenizer and the text encoder.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ uncond_prompt (`str` or `List[str]`):
+            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
+            is provided, the embeddings of prompt and uncond_prompt are concatenated.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
+            If the length of the text tokens is a multiple of the text encoder capacity, whether to keep the
+            starting and ending tokens in each middle chunk.
+ skip_parsing (`bool`, *optional*, defaults to `False`):
+ Skip the parsing of brackets.
+ skip_weighting (`bool`, *optional*, defaults to `False`):
+ Skip the weighting. When the parsing is skipped, it is forced True.
+ """
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ if not skip_parsing:
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+ else:
+ prompt_tokens = [
+ token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
+ ]
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+ if uncond_prompt is not None:
+ if isinstance(uncond_prompt, str):
+ uncond_prompt = [uncond_prompt]
+ uncond_tokens = [
+ token[1:-1]
+ for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+ ]
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
+ max_length = max([len(token) for token in prompt_tokens])
+ if uncond_prompt is not None:
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+ max_embeddings_multiples = min(
+ max_embeddings_multiples,
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
+ )
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+ # pad the length of tokens and weights
+ bos = pipe.tokenizer.bos_token_id
+ eos = pipe.tokenizer.eos_token_id
+ pad = getattr(pipe.tokenizer, "pad_token_id", eos)
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
+ prompt_tokens,
+ prompt_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
+ if uncond_prompt is not None:
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
+ uncond_tokens,
+ uncond_weights,
+ max_length,
+ bos,
+ eos,
+ pad,
+ no_boseos_middle=no_boseos_middle,
+ chunk_length=pipe.tokenizer.model_max_length,
+ )
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
+
+ # get the embeddings
+ text_embeddings = get_unweighted_text_embeddings(
+ pipe,
+ prompt_tokens,
+ pipe.tokenizer.model_max_length,
+ no_boseos_middle=no_boseos_middle,
+ )
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device)
+ if uncond_prompt is not None:
+ uncond_embeddings = get_unweighted_text_embeddings(
+ pipe,
+ uncond_tokens,
+ pipe.tokenizer.model_max_length,
+ no_boseos_middle=no_boseos_middle,
+ )
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device)
+
+ # assign weights to the prompts and normalize in the sense of mean
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
+ if (not skip_parsing) and (not skip_weighting):
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+ text_embeddings *= prompt_weights.unsqueeze(-1)
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+ if uncond_prompt is not None:
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
+
+ if uncond_prompt is not None:
+ return text_embeddings, uncond_embeddings
+ return text_embeddings, None
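
The last step of `get_weighted_text_embeddings` rescales the weighted embeddings so their mean matches the unweighted ones, which keeps emphasized prompts from changing the overall magnitude the UNet sees. A standalone toy illustration of just that normalization, with made-up tensors:

```py
import torch

emb = torch.randn(1, 4, 8)                       # (batch, tokens, dim)
w = torch.tensor([[1.0, 1.1, 1.1, 1.0]])         # per-token weights

prev_mean = emb.float().mean(dim=[-2, -1])
weighted = emb * w.unsqueeze(-1)
cur_mean = weighted.float().mean(dim=[-2, -1])
weighted = weighted * (prev_mean / cur_mean).unsqueeze(-1).unsqueeze(-1)

print(torch.allclose(weighted.mean(dim=[-2, -1]), prev_mean, atol=1e-5))   # True
```
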
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+            (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+ deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+ deprecate(
+ "prepare_mask_and_masked_image",
+ "0.30.0",
+ deprecation_message,
+ )
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t passed height an width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+class StableDiffusionInpaintPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+
+ Args:
+ vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+ model_cpu_offload_seq = "text_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+
+ def __init__(
+ self,
+ vae: Union[AutoencoderKL, AsymmetricAutoencoderKL],
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
+ if unet.config.in_channels != 9:
+ logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.")
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ max_embeddings_multiples=3,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if negative_prompt_embeds is None:
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * batch_size
+ if batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ if prompt_embeds is None or negative_prompt_embeds is None:
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer)
+
+ prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings(
+ pipe=self,
+ prompt=prompt,
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+ max_embeddings_multiples=max_embeddings_multiples,
+ )
+ if prompt_embeds is None:
+ prompt_embeds = prompt_embeds1
+ if negative_prompt_embeds is None:
+ negative_prompt_embeds = negative_prompt_embeds1
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
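+        # Illustrative: with DDIMScheduler, whose `step()` accepts both arguments, this returns
+        # {"eta": eta, "generator": generator}; schedulers whose `step()` accepts neither get an empty dict.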
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
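+        # For illustration: with a standard SD VAE (vae_scale_factor=8, 4 latent channels) and a 512x512 request,
+        # this shape is (batch_size, 4, 64, 64), i.e. latents live at 1/8 of the pixel resolution.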
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+                "Since strength < 1, the initial latents need to be initialised as a combination of image + noise."
+                " However, either the image or the noise timestep has not been provided."
+ )
+
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+
+ if image.shape[1] == 4:
+ image_latents = image
+ else:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1. then initialise the latents to noise, else initialise to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ else:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
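+        # Worked example: with num_inference_steps=50 and strength=0.6, init_timestep = min(int(50 * 0.6), 50) = 30
+        # and t_start = 20, so (for a first-order scheduler) only the last 30 of the 50 timesteps are run and 30 is
+        # returned as the effective number of inference steps.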
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
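+        # Example usage (values are the starting points suggested in the FreeU repository for SD1.x checkpoints;
+        # treat them as a hint, not a guarantee): pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)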
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.FloatTensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: int = None,
+ max_embeddings_multiples: Optional[int] = 4,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be inpainted (the parts of the image
+                that are masked out with `mask_image` are repainted according to `prompt`). For both numpy arrays
+                and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a list of
+                tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+                latents as `image`, but if latents are passed directly they are not encoded again.
+ mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+                color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+                `(B, H, W)`, `(1, H, W)`, or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`,
+                `(H, W, 1)`, or `(H, W)`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+                A function called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ Examples:
+
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+
+ >>> from diffusers import StableDiffusionInpaintPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = download_image(img_url).resize((512, 512))
+ >>> mask_image = download_image(mask_url).resize((512, 512))
+
+ >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+
+ # # For classifier free guidance, we need to do two forward passes.
+ # # Here we concatenate the unconditional and text embeddings into a single batch
+ # # to avoid doing two forward passes
+ # if do_classifier_free_guidance:
+ # prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ max_embeddings_multiples,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
+
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ condition_kwargs = {}
+ if isinstance(self.vae, AsymmetricAutoencoderKL):
+ init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
+ init_image_condition = init_image.clone()
+ init_image = self._encode_vae_image(init_image, generator=generator)
+ mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
+ condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/pipeline_stable_diffusion_xl_inpaint_lpw.py b/pipeline_stable_diffusion_xl_inpaint_lpw.py
new file mode 100644
index 000000000000..4f6653ec9a1d
--- /dev/null
+++ b/pipeline_stable_diffusion_xl_inpaint_lpw.py
@@ -0,0 +1,1416 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+ from .watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+from llmga.diffusers.my_utils.util import get_unweighted_text_embeddings_SDXL1, get_unweighted_text_embeddings_SDXL2, get_text_index
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
+ ... torch_dtype=torch.float16,
+ ... variant="fp16",
+ ... use_safetensors=True,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = load_image(img_url).convert("RGB")
+ >>> mask_image = load_image(mask_url).convert("RGB")
+
+ >>> prompt = "A majestic tiger sitting on a bench"
+ >>> image = pipe(
+ ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+ ... ).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
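+    # Note: guidance_rescale=0.0 returns the CFG prediction unchanged, while guidance_rescale=1.0 fully rescales it
+    # so that its per-sample standard deviation matches that of the text-conditioned prediction.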
+ return noise_cfg
+
+
+def mask_pil_to_torch(mask, height, width):
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask = torch.from_numpy(mask)
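+    # At this point the mask is a torch tensor shaped (batch, 1, height, width); PIL inputs have already been scaled
+    # to [0, 1], while numpy inputs are passed through as-is and are validated/binarized by the caller.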
+ return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+            (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+    # TODO(Yiyi) - need to clean this up later
+ deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+ deprecate(
+ "prepare_mask_and_masked_image",
+ "0.30.0",
+ deprecation_message,
+ )
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ mask = mask_pil_to_torch(mask, height, width)
+
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ # if image.min() < -1 or image.max() > 1:
+ # raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ mask = mask_pil_to_torch(mask, height, width)
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ if image.shape[1] == 4:
+ # images are in latent space and thus can't
+ # be masked set masked_image to None
+ # we assume that the checkpoint is not an inpainting
+            # checkpoint. TODO(Yiyi) - need to clean this up later
+ masked_image = None
+ else:
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
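+
+# Illustrative usage of the helper above (assuming `init_image` and `mask` are 512x512 PIL images):
+#     mask_t, masked_image_t = prepare_mask_and_masked_image(init_image, mask, 512, 512)
+# yields a binary (1, 1, 512, 512) mask tensor and a (1, 3, 512, 512) masked image normalized to [-1, 1].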
+
+
+class StableDiffusionXLInpaintPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+    Pipeline for text-guided image inpainting using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+
+ _optional_components = ["tokenizer", "text_encoder"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ requires_aesthetics_score: bool = False,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+ )
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
+        processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+            fg = 0
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                text_input_ids = get_text_index(tokenizer, prompt)
+
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+                if fg == 0:
+                    text_embeddings, hidden_states = get_unweighted_text_embeddings_SDXL1(
+                        text_encoder, text_input_ids.to(device), chunk_length=tokenizer.model_max_length, clip_skip=clip_skip
+                    )
+                    fg = 1
+                else:
+                    text_embeddings, hidden_states = get_unweighted_text_embeddings_SDXL2(
+                        text_encoder, text_input_ids.to(device), chunk_length=tokenizer.model_max_length, clip_skip=clip_skip
+                    )
+
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = text_embeddings
+
+
+ prompt_embeds_list.append(hidden_states)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+            fg = 1
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+
+                uncond_input = get_text_index(tokenizer, negative_prompt)
+
+                if fg == 0:
+                    negative_pooled_prompt_embeds, negative_prompt_embeds = get_unweighted_text_embeddings_SDXL1(
+                        text_encoder, uncond_input.to(device), chunk_length=tokenizer.model_max_length, clip_skip=clip_skip
+                    )
+                    fg = 1
+                else:
+                    negative_pooled_prompt_embeds, negative_prompt_embeds = get_unweighted_text_embeddings_SDXL2(
+                        text_encoder, uncond_input.to(device), chunk_length=tokenizer.model_max_length, clip_skip=clip_skip
+                    )
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
+ unscale_lora_layers(self.text_encoder_2)
+
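+        # Shape note (assuming the standard SDXL text encoders with hidden sizes 768 and 1280): prompt_embeds and
+        # negative_prompt_embeds come out as (batch * num_images_per_prompt, seq_len, 2048) because the two encoders'
+        # hidden states are concatenated on the last dim, while the pooled embeddings are
+        # (batch * num_images_per_prompt, 1280) from the second encoder.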
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+                "Since strength < 1, the initial latents need to be initialised as a combination of image + noise."
+                " However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1. then initialise the latents to noise, else initialise to image + noise
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
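+ # Encode pixel-space images into scaled VAE latents, temporarily upcasting to float32 when the VAE
+ # config requests it to avoid overflow in half precision.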
+ dtype = image.dtype
+ if self.vae.config.force_upcast:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+
+ if self.vae.config.force_upcast:
+ self.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ else:
+ t_start = 0
+
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ if denoising_start is not None:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
+ return torch.tensor(timesteps), len(timesteps)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ ):
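+ # Pack SDXL's micro-conditioning (original size, crop coordinates, and target size or aesthetic score)
+ # into the added time-embedding ids expected by the UNet.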
+ if self.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: torch.FloatTensor = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.9999,
+ num_inference_steps: int = 50,
+ denoising_start: Optional[float] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ aesthetic_score: float = 6.0,
+ negative_aesthetic_score: float = 2.5,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+ be masked out with `mask_image` and repainted according to `prompt`.
+ mask_image (`PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ strength (`float`, *optional*, defaults to 0.9999):
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+ portion of the reference `image`. Note that when `denoising_start` is specified, the value of
+ `strength` will be ignored.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_start (`float`, *optional*):
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be the same
+ as the `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=clip_skip,
+ )
+
+ # 4. set timesteps
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid(denoising_start) else None
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+ if masked_image_latents is not None:
+ masked_image = masked_image_latents
+ elif init_image.shape[1] == 4:
+ # if the image is already in latent space, we can't mask it
+ masked_image = None
+ else:
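+ # zero out the region to repaint (mask >= 0.5) so only the pixels to keep are encoded by the VAE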
+ masked_image = init_image * (mask < 0.5)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
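+ # a 4-channel text-to-image UNet needs the clean image latents to re-blend the unmasked area at every
+ # step, while a 9-channel inpainting UNet is conditioned on the mask and masked image instead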
+ return_image_latents = num_channels_unet == 4
+
+ add_noise = True if denoising_start is None else False
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ add_noise=add_noise,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+ # 8.1 Prepare extra step kwargs.
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Recompute the output height and width from the latent shape
+ height, width = latents.shape[-2:]
+ height = height * self.vae_scale_factor
+ width = width * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 10. Prepare added time ids & embeddings
+ if negative_original_size is None:
+ negative_original_size = original_size
+ if negative_target_size is None:
+ negative_target_size = target_size
+
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ )
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
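+ # keep the [negative, positive] ordering consistent with the prompt embeddings concatenated above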
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device)
+
+ # 11. Denoising loop
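+ # the progress bar advances once per completed scheduler step; higher-order schedulers consume
+ # `scheduler.order` timesteps per step, and any leftover timesteps are treated as warmup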
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ if (
+ denoising_end is not None
+ and denoising_start is not None
+ and denoising_value_valid(denoising_end)
+ and denoising_value_valid(denoising_start)
+ and denoising_start >= denoising_end
+ ):
+ raise ValueError(
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {denoising_end} when using type float."
+ )
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents
+ if do_classifier_free_guidance:
+ init_mask, _ = mask.chunk(2)
+ else:
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
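+ # outside the mask keep the (re-noised) original image latents, inside the mask keep the denoised latents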
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
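+ # latents were multiplied by `scaling_factor` at encode time, so divide it back out before decoding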
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ return StableDiffusionXLPipelineOutput(images=latents)
+
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/pipeline_stable_diffusion_xl_lpw.py b/pipeline_stable_diffusion_xl_lpw.py
new file mode 100644
index 000000000000..a5c451baa081
--- /dev/null
+++ b/pipeline_stable_diffusion_xl_lpw.py
@@ -0,0 +1,944 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+if is_invisible_watermark_available():
+ from .watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+import re
+from llmga.diffusers.my_utils.util import get_unweighted_text_embeddings_SDXL1, get_unweighted_text_embeddings_SDXL2, get_text_index
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipeline
+
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+class StableDiffusionXLPipeline(
+ DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+ as well as the following saving methods:
+ - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
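+ # spatial downsampling factor of the VAE (8 for SDXL, so a 1024x1024 image maps to a 128x128 latent)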
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = self.unet.config.sample_size
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ fg = 0
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_input_ids = get_text_index(tokenizer, prompt)
+
+
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if fg == 0:
+ text_embeddings, hidden_states = get_unweighted_text_embeddings_SDXL1(
+ text_encoder, text_input_ids.to(device), chunk_length=tokenizer.model_max_length, clip_skip=clip_skip
+ )
+ fg = 1
+ else:
+ text_embeddings, hidden_states = get_unweighted_text_embeddings_SDXL2(
+ text_encoder, text_input_ids.to(device), chunk_length=tokenizer.model_max_length, clip_skip=clip_skip
+ )
+
+ # We are always only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = text_embeddings
+
+
+ prompt_embeds_list.append(hidden_states)
+
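+ # concatenate the hidden states from both text encoders (CLIP ViT-L and OpenCLIP bigG) along the feature dim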
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ fg = 1
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+
+ uncond_input = get_text_index(tokenizer, negative_prompt)
+
+ if fg == 0:
+ negative_pooled_prompt_embeds, negative_prompt_embeds = get_unweighted_text_embeddings_SDXL1(
+ text_encoder, uncond_input.to(device), chunk_length=tokenizer.model_max_length, clip_skip=clip_skip
+ )
+ fg = 1
+ else:
+ negative_pooled_prompt_embeds, negative_prompt_embeds = get_unweighted_text_embeddings_SDXL2(
+ text_encoder, uncond_input.to(device), chunk_length=tokenizer.model_max_length, clip_skip=clip_skip
+ )
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
+ unscale_lora_layers(self.text_encoder_2)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
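+        # SDXL micro-conditioning: concatenating the three tuples yields the six integers
+        # (orig_h, orig_w, crop_top, crop_left, target_h, target_w) that are embedded as added time ids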
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+        # if xformers or torch_2_0 is used, the attention block does not need
+        # to be in float32, which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
+                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the same
+                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=clip_skip,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if do_classifier_free_guidance:
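+            # classifier-free guidance: stack the unconditional (negative) and text-conditioned
+            # inputs so both branches run through the UNet in a single forward pass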
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
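+        # keep only timesteps >= num_train_timesteps * (1 - denoising_end), i.e. stop the loop early
+        # and hand back a partially denoised latent (useful for base + refiner "Mixture of Denoisers" setups)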
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
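+                        # schedulers with order > 1 consume several timesteps per denoising step,
+                        # so divide by the order to report a consistent step index to the callback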
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
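
The `__call__` implemented above follows the stock SDXL text-to-image interface. As a minimal usage sketch (not part of this diff, and assuming the standard `StableDiffusionXLPipeline` export plus the public SDXL base checkpoint), the documented arguments are exercised like this:

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# Optional: FreeU, as wired up by enable_freeu() above. The scaling factors here are only
# illustrative; see the FreeU repository for combinations known to work well with SDXL.
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)

image = pipe(
    prompt="a photograph of an astronaut riding a horse",
    negative_prompt="low quality, blurry",
    num_inference_steps=30,
    guidance_scale=5.0,  # > 1 enables classifier-free guidance, matching the check in the code above
    # SDXL micro-conditioning, matching the docstring above
    original_size=(1024, 1024),
    target_size=(1024, 1024),
    crops_coords_top_left=(0, 0),
).images[0]
image.save("astronaut.png")
```
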
diff --git a/pyproject.toml b/pyproject.toml
index 0612f2f9e059..a5fe70af9ca7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,10 @@
+[tool.black]
+line-length = 119
+target-version = ['py37']
+
[tool.ruff]
# Never enforce `E501` (line length violations).
-ignore = ["C901", "E501", "E741", "F402", "F823"]
+ignore = ["C901", "E501", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]
line-length = 119
@@ -12,16 +16,3 @@ line-length = 119
[tool.ruff.isort]
lines-after-imports = 2
known-first-party = ["diffusers"]
-
-[tool.ruff.format]
-# Like Black, use double quotes for strings.
-quote-style = "double"
-
-# Like Black, indent with spaces, rather than tabs.
-indent-style = "space"
-
-# Like Black, respect magic trailing commas.
-skip-magic-trailing-comma = false
-
-# Like Black, automatically detect the appropriate line ending.
-line-ending = "auto"
diff --git a/scripts/convert_kakao_brain_unclip_to_diffusers.py b/scripts/convert_kakao_brain_unclip_to_diffusers.py
index b02cb498bb9b..85d983dea686 100644
--- a/scripts/convert_kakao_brain_unclip_to_diffusers.py
+++ b/scripts/convert_kakao_brain_unclip_to_diffusers.py
@@ -11,7 +11,7 @@
from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
-r"""
+"""
Example - From the diffusers root directory:
Download weights:
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 000000000000..fe555d61c69a
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,20 @@
+[isort]
+default_section = FIRSTPARTY
+ensure_newline_before_comments = True
+force_grid_wrap = 0
+include_trailing_comma = True
+known_first_party = accelerate
+known_third_party =
+ numpy
+ torch
+ torch_xla
+
+line_length = 119
+lines_after_imports = 2
+multi_line_output = 3
+use_parentheses = True
+
+[flake8]
+ignore = E203, E722, E501, E741, W503, W605
+max-line-length = 119
+per-file-ignores = __init__.py:F401
diff --git a/setup.py b/setup.py
index 3ec9a54f8215..7ad5646d4fca 100644
--- a/setup.py
+++ b/setup.py
@@ -15,12 +15,12 @@
"""
Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py
-To create the package for PyPI.
+To create the package for pypi.
1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
documentation.
- If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make
+ If releasing on a special branch, copy the updated README.md on the main branch for your the commit you will make
for the post-release and run `make fix-copies` on the main branch as well.
2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid.
@@ -30,29 +30,29 @@
4. Checkout the release branch (v-release, for example v4.19-release), and commit these changes with the
message: "Release: " and push.
-5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs).
+5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs)
-6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for PyPI'"
+6. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for pypi' "
Push the tag to git: git push --tags origin v-release
7. Build both the sources and the wheel. Do not change anything in setup.py between
creating the wheel and the source distribution (obviously).
- For the wheel, run: "python setup.py bdist_wheel" in the top level directory
- (This will build a wheel for the Python version you use to build it).
+ For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
+ (this will build a wheel for the python version you use to build it).
For the sources, run: "python setup.py sdist"
You should now have a /dist directory with both .whl and .tar.gz source versions.
- Long story cut short, you need to run both before you can upload the distribution to the
- test PyPI and the actual PyPI servers:
-
+ Long story cut short, you need to run both before you can upload the distribution to the
+ test pypi and the actual pypi servers:
+
python setup.py bdist_wheel && python setup.py sdist
-8. Check that everything looks correct by uploading the package to the PyPI test server:
+8. Check that everything looks correct by uploading the package to the pypi test server:
twine upload dist/* -r pypitest
- (pypi suggests using twine as other methods upload files via plaintext.)
+ (pypi suggest using twine as other methods upload files via plaintext.)
You may have to specify the repository url, use the following command then:
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
@@ -64,15 +64,15 @@
pip install -i https://testpypi.python.org/pypi diffusers
Check you can run the following commands:
- python -c "from diffusers import __version__; print(__version__)"
+ python -c "python -c "from diffusers import __version__; print(__version__)"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
python -c "from diffusers import *"
-9. Upload the final version to the actual PyPI:
+9. Upload the final version to actual pypi:
twine upload dist/* -r pypi
-10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory.
+10. Prepare the release notes and publish them on github once everything is looking hunky-dory.
11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
you need to go back to main before executing this.
@@ -80,7 +80,6 @@
import os
import re
-import sys
from distutils.core import Command
from setuptools import find_packages, setup
@@ -93,11 +92,12 @@
"Pillow", # keep the PIL.Image.Resampling deprecation away
"accelerate>=0.11.0",
"compel==0.1.8",
+ "black~=23.1",
"datasets",
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
- "huggingface-hub>=0.19.4",
+ "huggingface-hub>=0.13.2",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
@@ -112,13 +112,11 @@
"numpy",
"omegaconf",
"parameterized",
- "peft>=0.6.0",
"protobuf>=3.20.3,<4",
"pytest",
"pytest-timeout",
"pytest-xdist",
- "python>=3.8.0",
- "ruff>=0.1.5,<=0.2",
+ "ruff==0.0.280",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
"scipy",
@@ -144,7 +142,7 @@
# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with:
#
# python -c 'import sys; from diffusers.dependency_versions_table import deps; \
-# print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets
+# print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets
#
# Just pass the desired package names to that script as it's shown with 2 packages above.
#
@@ -153,7 +151,7 @@
# You can then feed this for example to `pip`:
#
# pip install -U $(python -c 'import sys; from diffusers.dependency_versions_table import deps; \
-# print(" ".join([deps[x] for x in sys.argv[1:]]))' tokenizers datasets)
+# print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets)
#
@@ -170,11 +168,7 @@ class DepsTableUpdateCommand(Command):
description = "build runtime dependency table"
user_options = [
# format: (long option, short option, description).
- (
- "dep-table-update",
- None,
- "updates src/diffusers/dependency_versions_table.py",
- ),
+ ("dep-table-update", None, "updates src/diffusers/dependency_versions_table.py"),
]
def initialize_options(self):
@@ -188,7 +182,7 @@ def run(self):
content = [
"# THIS FILE HAS BEEN AUTOGENERATED. To update:",
"# 1. modify the `_deps` dict in setup.py",
- "# 2. run `make deps_table_update`",
+ "# 2. run `make deps_table_update``",
"deps = {",
entries,
"}",
@@ -201,7 +195,10 @@ def run(self):
extras = {}
-extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder")
+
+
+extras = {}
+extras["quality"] = deps_list("urllib3", "black", "isort", "ruff", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
extras["test"] = deps_list(
@@ -245,11 +242,9 @@ def run(self):
deps["Pillow"],
]
-version_range_max = max(sys.version_info[1], 10) + 1
-
setup(
name="diffusers",
- version="0.24.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.22.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
@@ -273,30 +268,30 @@ def run(self):
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3",
- ]
- + [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)],
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ],
cmdclass={"deps_table_update": DepsTableUpdateCommand},
)
-
# Release checklist
# 1. Change the version in __init__.py and setup.py.
# 2. Commit these changes with the message: "Release: Release"
-# 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for PyPI'"
+# 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for pypi' "
# Push the tag to git: git push --tags origin main
# 4. Run the following commands in the top-level directory:
# python setup.py bdist_wheel
# python setup.py sdist
-# 5. Upload the package to the PyPI test server first:
+# 5. Upload the package to the pypi test server first:
# twine upload dist/* -r pypitest
# twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
# 6. Check that you can install it in a virtualenv by running:
# pip install -i https://testpypi.python.org/pypi diffusers
# diffusers env
# diffusers test
-# 7. Upload the final version to the actual PyPI:
+# 7. Upload the final version to actual pypi:
# twine upload dist/* -r pypi
-# 8. Add release notes to the tag in GitHub once everything is looking hunky-dory.
-# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to main.
+# 8. Add release notes to the tag in github once everything is looking hunky-dory.
+# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to master
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 209508cf1bac..42f352c029c8 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.24.0"
+__version__ = "0.22.0.dev0"
from typing import TYPE_CHECKING
@@ -76,13 +76,9 @@
[
"AsymmetricAutoencoderKL",
"AutoencoderKL",
- "AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
- "ConsistencyDecoderVAE",
"ControlNetModel",
- "Kandinsky3UNet",
"ModelMixin",
- "MotionAdapter",
"MultiAdapter",
"PriorTransformer",
"T2IAdapter",
@@ -92,12 +88,9 @@
"UNet2DConditionModel",
"UNet2DModel",
"UNet3DConditionModel",
- "UNetMotionModel",
- "UNetSpatioTemporalConditionModel",
"VQModel",
]
)
-
_import_structure["optimization"] = [
"get_constant_schedule",
"get_constant_schedule_with_warmup",
@@ -107,6 +100,7 @@
"get_polynomial_decay_schedule_with_warmup",
"get_scheduler",
]
+
_import_structure["pipelines"].extend(
[
"AudioPipelineOutput",
@@ -148,7 +142,6 @@
"KarrasVeScheduler",
"KDPM2AncestralDiscreteScheduler",
"KDPM2DiscreteScheduler",
- "LCMScheduler",
"PNDMScheduler",
"RePaintScheduler",
"SchedulerMixin",
@@ -201,7 +194,6 @@
[
"AltDiffusionImg2ImgPipeline",
"AltDiffusionPipeline",
- "AnimateDiffPipeline",
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
@@ -217,8 +209,6 @@
"IFPipeline",
"IFSuperResolutionPipeline",
"ImageTextPipelineOutput",
- "Kandinsky3Img2ImgPipeline",
- "Kandinsky3Pipeline",
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
@@ -236,12 +226,9 @@
"KandinskyV22Pipeline",
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
- "LatentConsistencyModelImg2ImgPipeline",
- "LatentConsistencyModelPipeline",
"LDMTextToImagePipeline",
"MusicLDMPipeline",
"PaintByExamplePipeline",
- "PixArtAlphaPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -279,10 +266,8 @@
"StableDiffusionXLPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
- "StableVideoDiffusionPipeline",
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
- "TextToVideoZeroSDXLPipeline",
"UnCLIPImageVariationPipeline",
"UnCLIPPipeline",
"UniDiffuserModel",
@@ -450,13 +435,9 @@
from .models import (
AsymmetricAutoencoderKL,
AutoencoderKL,
- AutoencoderKLTemporalDecoder,
AutoencoderTiny,
- ConsistencyDecoderVAE,
ControlNetModel,
- Kandinsky3UNet,
ModelMixin,
- MotionAdapter,
MultiAdapter,
PriorTransformer,
T2IAdapter,
@@ -466,8 +447,6 @@
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
- UNetMotionModel,
- UNetSpatioTemporalConditionModel,
VQModel,
)
from .optimization import (
@@ -520,7 +499,6 @@
KarrasVeScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
- LCMScheduler,
PNDMScheduler,
RePaintScheduler,
SchedulerMixin,
@@ -556,7 +534,6 @@
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,
- AnimateDiffPipeline,
AudioLDM2Pipeline,
AudioLDM2ProjectionModel,
AudioLDM2UNet2DConditionModel,
@@ -570,8 +547,6 @@
IFPipeline,
IFSuperResolutionPipeline,
ImageTextPipelineOutput,
- Kandinsky3Img2ImgPipeline,
- Kandinsky3Pipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
@@ -589,12 +564,9 @@
KandinskyV22Pipeline,
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
- LatentConsistencyModelImg2ImgPipeline,
- LatentConsistencyModelPipeline,
LDMTextToImagePipeline,
MusicLDMPipeline,
PaintByExamplePipeline,
- PixArtAlphaPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
@@ -632,10 +604,8 @@
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
- StableVideoDiffusionPipeline,
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
- TextToVideoZeroSDXLPipeline,
UnCLIPImageVariationPipeline,
UnCLIPPipeline,
UniDiffuserModel,
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 1b91bfda3058..9bc25155a0b6 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -95,7 +95,6 @@ class ConfigMixin:
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
subclass).
"""
-
config_name = None
ignore_for_config = []
has_compatibles = False
@@ -486,18 +485,10 @@ def extract_init_dict(cls, config_dict, **kwargs):
# remove attributes from orig class that cannot be expected
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
- if (
- isinstance(orig_cls_name, str)
- and orig_cls_name != cls.__name__
- and hasattr(diffusers_library, orig_cls_name)
- ):
+ if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
orig_cls = getattr(diffusers_library, orig_cls_name)
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
- elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
- raise ValueError(
- "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
- )
# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
diff --git a/src/diffusers/dependency_versions_check.py b/src/diffusers/dependency_versions_check.py
index 0144db201aa1..4f8578c52957 100644
--- a/src/diffusers/dependency_versions_check.py
+++ b/src/diffusers/dependency_versions_check.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import sys
from .dependency_versions_table import deps
from .utils.versions import require_version, require_version_core
@@ -22,9 +23,21 @@
# order specific notes:
# - tqdm must be checked before tokenizers
-pkgs_to_check_at_runtime = "python requests filelock numpy".split()
+pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
+if sys.version_info < (3, 7):
+ pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+ pkgs_to_check_at_runtime.append("importlib_metadata")
+
for pkg in pkgs_to_check_at_runtime:
if pkg in deps:
+ if pkg == "tokenizers":
+ # must be loaded here, or else tqdm check may fail
+ from .utils import is_tokenizers_available
+
+ if not is_tokenizers_available():
+ continue # not required, check version only if installed
+
require_version_core(deps[pkg])
else:
raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 7ec2e2cf6d5c..970013c31a20 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -1,15 +1,16 @@
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
-# 2. run `make deps_table_update`
+# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0",
"compel": "compel==0.1.8",
+ "black": "black~=23.1",
"datasets": "datasets",
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
- "huggingface-hub": "huggingface-hub>=0.19.4",
+ "huggingface-hub": "huggingface-hub>=0.13.2",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
@@ -24,13 +25,11 @@
"numpy": "numpy",
"omegaconf": "omegaconf",
"parameterized": "parameterized",
- "peft": "peft>=0.6.0",
"protobuf": "protobuf>=3.20.3,<4",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
- "python": "python>=3.8.0",
- "ruff": "ruff>=0.1.5,<=0.2",
+ "ruff": "ruff==0.0.280",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"scipy": "scipy",
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index a515805fd087..28a12f2d1364 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -13,7 +13,7 @@
# limitations under the License.
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
import numpy as np
import PIL.Image
@@ -33,15 +33,6 @@
List[torch.FloatTensor],
]
-PipelineDepthInput = Union[
- PIL.Image.Image,
- np.ndarray,
- torch.FloatTensor,
- List[PIL.Image.Image],
- List[np.ndarray],
- List[torch.FloatTensor],
-]
-
class VaeImageProcessor(ConfigMixin):
"""
@@ -135,14 +126,14 @@ def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
return images
@staticmethod
- def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
+ def normalize(images):
"""
Normalize an image array to [-1,1].
"""
return 2.0 * images - 1.0
@staticmethod
- def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
+ def denormalize(images):
"""
Denormalize an image array to [0,1].
"""
@@ -168,10 +159,10 @@ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
def get_default_height_width(
self,
- image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+ image: [PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
- ) -> Tuple[int, int]:
+ ):
"""
This function return the height and width that are downscaled to the next integer multiple of
`vae_scale_factor`.
@@ -211,24 +202,12 @@ def get_default_height_width(
def resize(
self,
- image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+ image: [PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
- ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
+ ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Resize image.
-
- Args:
- image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
- The image input, can be a PIL image, numpy array or pytorch tensor.
- height (`int`, *optional*, defaults to `None`):
- The height to resize to.
- width (`int`, *optional*`, defaults to `None`):
- The width to resize to.
-
- Returns:
- `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
- The resized image.
"""
if isinstance(image, PIL.Image.Image):
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
@@ -248,15 +227,7 @@ def resize(
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
"""
- Create a mask.
-
- Args:
- image (`PIL.Image.Image`):
- The image input, should be a PIL image.
-
- Returns:
- `PIL.Image.Image`:
- The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
+ create a mask
"""
image[image < 0.5] = 0
image[image >= 0.5] = 1
@@ -335,7 +306,7 @@ def preprocess(
# expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize
- if do_normalize and image.min() < 0:
+ if image.min() < 0 and do_normalize:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
@@ -356,23 +327,7 @@ def postprocess(
image: torch.FloatTensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
- ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
- """
- Postprocess the image output from tensor to `output_type`.
-
- Args:
- image (`torch.FloatTensor`):
- The image input, should be a pytorch tensor with shape `B x C x H x W`.
- output_type (`str`, *optional*, defaults to `pil`):
- The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
- do_denormalize (`List[bool]`, *optional*, defaults to `None`):
- Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
- `VaeImageProcessor` config.
-
- Returns:
- `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
- The postprocessed image.
- """
+ ):
if not isinstance(image, torch.Tensor):
raise ValueError(
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
@@ -435,7 +390,7 @@ def __init__(
super().__init__()
@staticmethod
- def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
+ def numpy_to_pil(images):
"""
Convert a NumPy image or a batch of images to a PIL image.
"""
@@ -451,19 +406,7 @@ def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
return pil_images
@staticmethod
- def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
- """
- Convert a PIL image or a list of PIL images to NumPy arrays.
- """
- if not isinstance(images, list):
- images = [images]
-
- images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
- images = np.stack(images, axis=0)
- return images
-
- @staticmethod
- def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
+ def rgblike_to_depthmap(image):
"""
Args:
image: RGB-like depth image
@@ -473,7 +416,7 @@ def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndar
"""
return image[:, :, 1] * 2**8 + image[:, :, 2]
- def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
+ def numpy_to_depth(self, images):
"""
Convert a NumPy depth image or a batch of images to a PIL image.
"""
@@ -498,23 +441,7 @@ def postprocess(
image: torch.FloatTensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
- ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
- """
- Postprocess the image output from tensor to `output_type`.
-
- Args:
- image (`torch.FloatTensor`):
- The image input, should be a pytorch tensor with shape `B x C x H x W`.
- output_type (`str`, *optional*, defaults to `pil`):
- The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
- do_denormalize (`List[bool]`, *optional*, defaults to `None`):
- Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
- `VaeImageProcessor` config.
-
- Returns:
- `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
- The postprocessed image.
- """
+ ):
if not isinstance(image, torch.Tensor):
raise ValueError(
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
@@ -547,102 +474,3 @@ def postprocess(
return self.numpy_to_pil(image), self.numpy_to_depth(image)
else:
raise Exception(f"This type {output_type} is not supported")
-
- def preprocess(
- self,
- rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
- depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
- height: Optional[int] = None,
- width: Optional[int] = None,
- target_res: Optional[int] = None,
- ) -> torch.Tensor:
- """
- Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
- """
- supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
-
- # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
- if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
- raise Exception("This is not yet supported")
-
- if isinstance(rgb, supported_formats):
- rgb = [rgb]
- depth = [depth]
- elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
- raise ValueError(
- f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
- )
-
- if isinstance(rgb[0], PIL.Image.Image):
- if self.config.do_convert_rgb:
- raise Exception("This is not yet supported")
- # rgb = [self.convert_to_rgb(i) for i in rgb]
- # depth = [self.convert_to_depth(i) for i in depth] #TODO define convert_to_depth
- if self.config.do_resize or target_res:
- height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
- rgb = [self.resize(i, height, width) for i in rgb]
- depth = [self.resize(i, height, width) for i in depth]
- rgb = self.pil_to_numpy(rgb) # to np
- rgb = self.numpy_to_pt(rgb) # to pt
-
- depth = self.depth_pil_to_numpy(depth) # to np
- depth = self.numpy_to_pt(depth) # to pt
-
- elif isinstance(rgb[0], np.ndarray):
- rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
- rgb = self.numpy_to_pt(rgb)
- height, width = self.get_default_height_width(rgb, height, width)
- if self.config.do_resize:
- rgb = self.resize(rgb, height, width)
-
- depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0)
- depth = self.numpy_to_pt(depth)
- height, width = self.get_default_height_width(depth, height, width)
- if self.config.do_resize:
- depth = self.resize(depth, height, width)
-
- elif isinstance(rgb[0], torch.Tensor):
- raise Exception("This is not yet supported")
- # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
-
- # if self.config.do_convert_grayscale and rgb.ndim == 3:
- # rgb = rgb.unsqueeze(1)
-
- # channel = rgb.shape[1]
-
- # height, width = self.get_default_height_width(rgb, height, width)
- # if self.config.do_resize:
- # rgb = self.resize(rgb, height, width)
-
- # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
-
- # if self.config.do_convert_grayscale and depth.ndim == 3:
- # depth = depth.unsqueeze(1)
-
- # channel = depth.shape[1]
- # # don't need any preprocess if the image is latents
- # if depth == 4:
- # return rgb, depth
-
- # height, width = self.get_default_height_width(depth, height, width)
- # if self.config.do_resize:
- # depth = self.resize(depth, height, width)
- # expected range [0,1], normalize to [-1,1]
- do_normalize = self.config.do_normalize
- if rgb.min() < 0 and do_normalize:
- warnings.warn(
- "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
- f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
- FutureWarning,
- )
- do_normalize = False
-
- if do_normalize:
- rgb = self.normalize(rgb)
- depth = self.normalize(depth)
-
- if self.config.do_binarize:
- rgb = self.binarize(rgb)
- depth = self.binarize(depth)
-
- return rgb, depth
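
For reference, the `VaeImageProcessor.normalize` / `denormalize` helpers touched in the hunks above are just an affine map between `[0, 1]` and `[-1, 1]`. A standalone round-trip check of that map (plain NumPy, independent of diffusers, with the inverse written out explicitly) might look like:

```py
import numpy as np

def normalize(images):
    # [0, 1] -> [-1, 1], mirroring the 2.0 * x - 1.0 formula shown above
    return 2.0 * images - 1.0

def denormalize(images):
    # inverse map back to [0, 1]
    return np.clip(images / 2 + 0.5, 0.0, 1.0)

x = np.random.rand(1, 3, 8, 8).astype(np.float32)  # fake image batch in [0, 1]
assert np.allclose(denormalize(normalize(x)), x, atol=1e-6)
```
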
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
new file mode 100644
index 000000000000..695a22d955da
--- /dev/null
+++ b/src/diffusers/loaders.py
@@ -0,0 +1,3325 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+from collections import defaultdict
+from contextlib import nullcontext
+from io import BytesIO
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Union
+
+import requests
+import safetensors
+import torch
+from huggingface_hub import hf_hub_download, model_info
+from packaging import version
+from torch import nn
+
+from . import __version__
+from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from .utils import (
+ DIFFUSERS_CACHE,
+ HF_HUB_OFFLINE,
+ USE_PEFT_BACKEND,
+ _get_model_file,
+ convert_state_dict_to_diffusers,
+ convert_state_dict_to_peft,
+ convert_unet_state_dict_to_peft,
+ deprecate,
+ get_adapter_name,
+ get_peft_kwargs,
+ is_accelerate_available,
+ is_omegaconf_available,
+ is_transformers_available,
+ logging,
+ recurse_remove_peft_layers,
+ scale_lora_layers,
+ set_adapter_layers,
+ set_weights_and_activate_adapters,
+)
+from .utils.import_utils import BACKENDS_MAPPING
+
+
+if is_transformers_available():
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel
+
+if is_accelerate_available():
+ from accelerate import init_empty_weights
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
+TEXT_ENCODER_NAME = "text_encoder"
+UNET_NAME = "unet"
+
+LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
+TEXT_INVERSION_NAME = "learned_embeds.bin"
+TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
+
+CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
+LORA_DEPRECATION_MESSAGE = "You are using an old version of the LoRA backend. This will be deprecated in the next releases in favor of PEFT. Make sure to install the latest PEFT and transformers packages in the future."
+
+
+class PatchedLoraProjection(nn.Module):
+ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
+ super().__init__()
+ from .models.lora import LoRALinearLayer
+
+ self.regular_linear_layer = regular_linear_layer
+
+ device = self.regular_linear_layer.weight.device
+
+ if dtype is None:
+ dtype = self.regular_linear_layer.weight.dtype
+
+ self.lora_linear_layer = LoRALinearLayer(
+ self.regular_linear_layer.in_features,
+ self.regular_linear_layer.out_features,
+ network_alpha=network_alpha,
+ device=device,
+ dtype=dtype,
+ rank=rank,
+ )
+
+ self.lora_scale = lora_scale
+
+ # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
+ # when saving the whole text encoder model and when LoRA is unloaded or fused
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ if self.lora_linear_layer is None:
+ return self.regular_linear_layer.state_dict(
+ *args, destination=destination, prefix=prefix, keep_vars=keep_vars
+ )
+
+ return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
+
+ def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+ if self.lora_linear_layer is None:
+ return
+
+ dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
+
+ w_orig = self.regular_linear_layer.weight.data.float()
+ w_up = self.lora_linear_layer.up.weight.data.float()
+ w_down = self.lora_linear_layer.down.weight.data.float()
+
+ if self.lora_linear_layer.network_alpha is not None:
+ w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
+
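+        # fuse the low-rank update into the base weight: W_fused = W + lora_scale * (up @ down);
+        # w_up has already been rescaled by network_alpha / rank above when network_alpha is set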
+ fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+ if safe_fusing and torch.isnan(fused_weight).any().item():
+ raise ValueError(
+ "This LoRA weight seems to be broken. "
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+                " LoRA weights will not be fused."
+ )
+
+ self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+ # we can drop the lora layer now
+ self.lora_linear_layer = None
+
+ # offload the up and down matrices to CPU to not blow the memory
+ self.w_up = w_up.cpu()
+ self.w_down = w_down.cpu()
+ self.lora_scale = lora_scale
+
+ def _unfuse_lora(self):
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
+ return
+
+ fused_weight = self.regular_linear_layer.weight.data
+ dtype, device = fused_weight.dtype, fused_weight.device
+
+ w_up = self.w_up.to(device=device).float()
+ w_down = self.w_down.to(device).float()
+
+ unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+ self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+ self.w_up = None
+ self.w_down = None
+
+ def forward(self, input):
+ if self.lora_scale is None:
+ self.lora_scale = 1.0
+ if self.lora_linear_layer is None:
+ return self.regular_linear_layer(input)
+ return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
+
+
+def text_encoder_attn_modules(text_encoder):
+ attn_modules = []
+
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+ name = f"text_model.encoder.layers.{i}.self_attn"
+ mod = layer.self_attn
+ attn_modules.append((name, mod))
+ else:
+ raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+ return attn_modules
+
+
+def text_encoder_mlp_modules(text_encoder):
+ mlp_modules = []
+
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+ mlp_mod = layer.mlp
+ name = f"text_model.encoder.layers.{i}.mlp"
+ mlp_modules.append((name, mlp_mod))
+ else:
+ raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
+
+ return mlp_modules
+
+
+def text_encoder_lora_state_dict(text_encoder):
+ state_dict = {}
+
+ for name, module in text_encoder_attn_modules(text_encoder):
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
+
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
+
+ return state_dict
+
+
+class AttnProcsLayers(torch.nn.Module):
+ def __init__(self, state_dict: Dict[str, torch.Tensor]):
+ super().__init__()
+ self.layers = torch.nn.ModuleList(state_dict.values())
+ self.mapping = dict(enumerate(state_dict.keys()))
+ self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
+
+ # .processor for unet, .self_attn for text encoder
+ self.split_keys = [".processor", ".self_attn"]
+
+ # we add a hook to state_dict() and load_state_dict() so that the
+ # naming fits with `unet.attn_processors`
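+ # e.g. on save, a key like "layers.3.to_k_lora.down.weight" is renamed to
+ # f"{self.mapping[3]}.to_k_lora.down.weight"; `map_from` performs the inverse rename on load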
+ def map_to(module, state_dict, *args, **kwargs):
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ num = int(key.split(".")[1]) # 0 is always "layers"
+ new_key = key.replace(f"layers.{num}", module.mapping[num])
+ new_state_dict[new_key] = value
+
+ return new_state_dict
+
+ def remap_key(key, state_dict):
+ for k in self.split_keys:
+ if k in key:
+ return key.split(k)[0] + k
+
+ raise ValueError(
+ f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
+ )
+
+ def map_from(module, state_dict, *args, **kwargs):
+ all_keys = list(state_dict.keys())
+ for key in all_keys:
+ replace_key = remap_key(key, state_dict)
+ new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+
+ self._register_state_dict_hook(map_to)
+ self._register_load_state_dict_pre_hook(map_from, with_module=True)
+
+
+class UNet2DConditionLoadersMixin:
+ text_encoder_name = TEXT_ENCODER_NAME
+ unet_name = UNET_NAME
+
+ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ r"""
+ Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
+ defined in
+ [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
+ and be a `torch.nn.Module` class.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a directory (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+
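+ Example (a minimal usage sketch; the checkpoint path and weight file name below are placeholders):
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+ pipe.unet.load_attn_procs("path/to/lora-checkpoint", weight_name="pytorch_lora_weights.safetensors")
+ ```
+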
+ """
+ from .models.attention_processor import (
+ CustomDiffusionAttnProcessor,
+ )
+ from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
+
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ network_alphas = kwargs.pop("network_alphas", None)
+
+ _pipeline = kwargs.pop("_pipeline", None)
+
+ is_network_alphas_none = network_alphas is None
+
+ allow_pickle = False
+
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ model_file = None
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except IOError as e:
+ if not allow_pickle:
+ raise e
+ # try loading non-safetensors weights
+ pass
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = torch.load(model_file, map_location="cpu")
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ # fill attn processors
+ lora_layers_list = []
+
+ is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
+ is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
+
+ if is_lora:
+ # correct keys
+ state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
+
+ if network_alphas is not None:
+ network_alphas_keys = list(network_alphas.keys())
+ used_network_alphas_keys = set()
+
+ lora_grouped_dict = defaultdict(dict)
+ mapped_network_alphas = {}
+
+ all_keys = list(state_dict.keys())
+ for key in all_keys:
+ value = state_dict.pop(key)
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+ lora_grouped_dict[attn_processor_key][sub_key] = value
+
+ # Create another `mapped_network_alphas` dictionary so that we can properly map them.
+ if network_alphas is not None:
+ for k in network_alphas_keys:
+ if k.replace(".alpha", "") in key:
+ mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
+ used_network_alphas_keys.add(k)
+
+ if not is_network_alphas_none:
+ if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
+ raise ValueError(
+ f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+ )
+
+ if len(state_dict) > 0:
+ raise ValueError(
+ f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+ )
+
+ for key, value_dict in lora_grouped_dict.items():
+ attn_processor = self
+ for sub_key in key.split("."):
+ attn_processor = getattr(attn_processor, sub_key)
+
+ # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+ # or add_{k,v,q,out_proj}_proj_lora layers.
+ rank = value_dict["lora.down.weight"].shape[0]
+
+ if isinstance(attn_processor, LoRACompatibleConv):
+ in_features = attn_processor.in_channels
+ out_features = attn_processor.out_channels
+ kernel_size = attn_processor.kernel_size
+
+ ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+ with ctx():
+ lora = LoRAConv2dLayer(
+ in_features=in_features,
+ out_features=out_features,
+ rank=rank,
+ kernel_size=kernel_size,
+ stride=attn_processor.stride,
+ padding=attn_processor.padding,
+ network_alpha=mapped_network_alphas.get(key),
+ )
+ elif isinstance(attn_processor, LoRACompatibleLinear):
+ ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+ with ctx():
+ lora = LoRALinearLayer(
+ attn_processor.in_features,
+ attn_processor.out_features,
+ rank,
+ mapped_network_alphas.get(key),
+ )
+ else:
+ raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+ value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+ lora_layers_list.append((attn_processor, lora))
+
+ if low_cpu_mem_usage:
+ device = next(iter(value_dict.values())).device
+ dtype = next(iter(value_dict.values())).dtype
+ load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
+ else:
+ lora.load_state_dict(value_dict)
+
+ elif is_custom_diffusion:
+ attn_processors = {}
+ custom_diffusion_grouped_dict = defaultdict(dict)
+ for key, value in state_dict.items():
+ if len(value) == 0:
+ custom_diffusion_grouped_dict[key] = {}
+ else:
+ if "to_out" in key:
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+ else:
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+ custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+ for key, value_dict in custom_diffusion_grouped_dict.items():
+ if len(value_dict) == 0:
+ attn_processors[key] = CustomDiffusionAttnProcessor(
+ train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+ )
+ else:
+ cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+ hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+ train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+ attn_processors[key] = CustomDiffusionAttnProcessor(
+ train_kv=True,
+ train_q_out=train_q_out,
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ )
+ attn_processors[key].load_state_dict(value_dict)
+ elif USE_PEFT_BACKEND:
+ # In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
+ # on the Unet
+ pass
+ else:
+ raise ValueError(
+ f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+ )
+
+ #
+
+ def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
+ is_new_lora_format = all(
+ key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
+ )
+ if is_new_lora_format:
+ # Strip the `"unet"` prefix.
+ is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
+ if is_text_encoder_present:
+ warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
+ logger.warn(warn_message)
+ unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
+ state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
+ # change processor format to 'pure' LoRACompatibleLinear format
+ if any("processor" in k.split(".") for k in state_dict.keys()):
+
+ def format_to_lora_compatible(key):
+ if "processor" not in key.split("."):
+ return key
+ return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
+
+ state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
+
+ if network_alphas is not None:
+ network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
+ return state_dict, network_alphas
+
+ def save_attn_procs(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ **kwargs,
+ ):
+ r"""
+ Save an attention processor to a directory so that it can be reloaded using the
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save an attention processor to. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
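+
+ Example (a minimal usage sketch; the checkpoint paths below are placeholders):
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+ pipe.unet.load_attn_procs("path/to/lora-checkpoint")
+ pipe.unet.save_attn_procs("path/to/output_dir", weight_name="pytorch_lora_weights.safetensors")
+ ```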
+ """
+ from .models.attention_processor import (
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionAttnProcessor2_0,
+ CustomDiffusionXFormersAttnProcessor,
+ )
+
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ if safe_serialization:
+
+ def save_function(weights, filename):
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+ else:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ is_custom_diffusion = any(
+ isinstance(
+ x,
+ (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
+ )
+ for (_, x) in self.attn_processors.items()
+ )
+ if is_custom_diffusion:
+ model_to_save = AttnProcsLayers(
+ {
+ y: x
+ for (y, x) in self.attn_processors.items()
+ if isinstance(
+ x,
+ (
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionAttnProcessor2_0,
+ CustomDiffusionXFormersAttnProcessor,
+ ),
+ )
+ }
+ )
+ state_dict = model_to_save.state_dict()
+ for name, attn in self.attn_processors.items():
+ if len(attn.state_dict()) == 0:
+ state_dict[name] = {}
+ else:
+ model_to_save = AttnProcsLayers(self.attn_processors)
+ state_dict = model_to_save.state_dict()
+
+ if weight_name is None:
+ if safe_serialization:
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
+ else:
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, weight_name))
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+ def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
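+ r"""
+ Fuse the currently loaded LoRA parameters into the base weights of all supporting sub-modules in place,
+ scaling the LoRA contribution by `lora_scale`. When `safe_fusing=True`, an error is raised if fusing would
+ produce NaN values. Use `unfuse_lora` to revert the fusion.
+ """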
+ self.lora_scale = lora_scale
+ self._safe_fusing = safe_fusing
+ self.apply(self._fuse_lora_apply)
+
+ def _fuse_lora_apply(self, module):
+ if not USE_PEFT_BACKEND:
+ if hasattr(module, "_fuse_lora"):
+ module._fuse_lora(self.lora_scale, self._safe_fusing)
+ else:
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ if isinstance(module, BaseTunerLayer):
+ if self.lora_scale != 1.0:
+ module.scale_layer(self.lora_scale)
+ module.merge(safe_merge=self._safe_fusing)
+
+ def unfuse_lora(self):
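+ r"""
+ Reverse `fuse_lora` by removing the previously fused LoRA contribution from the base weights (or by
+ unmerging the adapters when the PEFT backend is used).
+ """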
+ self.apply(self._unfuse_lora_apply)
+
+ def _unfuse_lora_apply(self, module):
+ if not USE_PEFT_BACKEND:
+ if hasattr(module, "_unfuse_lora"):
+ module._unfuse_lora()
+ else:
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ if isinstance(module, BaseTunerLayer):
+ module.unmerge()
+
+ def set_adapters(
+ self,
+ adapter_names: Union[List[str], str],
+ weights: Optional[Union[List[float], float]] = None,
+ ):
+ """
+ Sets the adapter layers for the unet.
+
+ Args:
+ adapter_names (`List[str]` or `str`):
+ The names of the adapters to use.
+ weights (`Union[List[float], float]`, *optional*):
+ The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+ adapters.
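+
+ Example (a minimal sketch; requires the PEFT backend, and the checkpoint paths and adapter names below are
+ placeholders):
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+ pipe.load_lora_weights("path/to/first-lora", adapter_name="style_a")
+ pipe.load_lora_weights("path/to/second-lora", adapter_name="style_b")
+ pipe.unet.set_adapters(["style_a", "style_b"], weights=[0.7, 0.3])
+ ```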
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for `set_adapters()`.")
+
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+ if weights is None:
+ weights = [1.0] * len(adapter_names)
+ elif isinstance(weights, float):
+ weights = [weights] * len(adapter_names)
+
+ if len(adapter_names) != len(weights):
+ raise ValueError(
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
+ )
+
+ set_weights_and_activate_adapters(self, adapter_names, weights)
+
+ def disable_lora(self):
+ """
+ Disables the active LoRA layers for the unet.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+ set_adapter_layers(self, enabled=False)
+
+ def enable_lora(self):
+ """
+ Enables the active LoRA layers for the unet.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+ set_adapter_layers(self, enabled=True)
+
+
+def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "text_inversion",
+ "framework": "pytorch",
+ }
+ state_dicts = []
+ for pretrained_model_name_or_path in pretrained_model_name_or_paths:
+ if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
+ # 3.1. Load textual inversion file
+ model_file = None
+
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except Exception as e:
+ if not allow_pickle:
+ raise e
+
+ model_file = None
+
+ if model_file is None:
+ model_file = _get_model_file(
+ pretrained_model_name_or_path,
+ weights_name=weight_name or TEXT_INVERSION_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = torch.load(model_file, map_location="cpu")
+ else:
+ state_dict = pretrained_model_name_or_path
+
+ state_dicts.append(state_dict)
+
+ return state_dicts
+
+
+class TextualInversionLoaderMixin:
+ r"""
+ Load textual inversion tokens and embeddings to the tokenizer and text encoder.
+ """
+
+ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
+ r"""
+ Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
+ be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
+ inversion token or if the textual inversion token is a single vector, the input prompt is returned.
+
+ Parameters:
+ prompt (`str` or list of `str`):
+ The prompt or prompts to guide the image generation.
+ tokenizer (`PreTrainedTokenizer`):
+ The tokenizer responsible for encoding the prompt into input tokens.
+
+ Returns:
+ `str` or list of `str`: The converted prompt
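+
+ Example (a sketch; assumes a pipeline `pipe` that has already loaded a multi-vector textual inversion
+ embedding for the placeholder token `<cat-toy>`):
+
+ ```py
+ # for e.g. a 3-vector embedding, the token is expanded to "<cat-toy> <cat-toy>_1 <cat-toy>_2"
+ prompt = pipe.maybe_convert_prompt("A <cat-toy> backpack", pipe.tokenizer)
+ ```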
+ """
+ if not isinstance(prompt, List):
+ prompts = [prompt]
+ else:
+ prompts = prompt
+
+ prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
+
+ if not isinstance(prompt, List):
+ return prompts[0]
+
+ return prompts
+
+ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
+ r"""
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
+
+ Parameters:
+ prompt (`str`):
+ The prompt to guide the image generation.
+ tokenizer (`PreTrainedTokenizer`):
+ The tokenizer responsible for encoding the prompt into input tokens.
+
+ Returns:
+ `str`: The converted prompt
+ """
+ tokens = tokenizer.tokenize(prompt)
+ unique_tokens = set(tokens)
+ for token in unique_tokens:
+ if token in tokenizer.added_tokens_encoder:
+ replacement = token
+ i = 1
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+ replacement += f" {token}_{i}"
+ i += 1
+
+ prompt = prompt.replace(token, replacement)
+
+ return prompt
+
+ def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
+ if tokenizer is None:
+ raise ValueError(
+ f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
+ f" `{self.load_textual_inversion.__name__}`"
+ )
+
+ if text_encoder is None:
+ raise ValueError(
+ f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
+ f" `{self.load_textual_inversion.__name__}`"
+ )
+
+ if len(pretrained_model_name_or_paths) != len(tokens):
+ raise ValueError(
+ f"You have passed a list of models of length {len(pretrained_model_name_or_paths)} and a list of tokens of length {len(tokens)}. "
+ f"Make sure both lists have the same length."
+ )
+
+ valid_tokens = [t for t in tokens if t is not None]
+ if len(set(valid_tokens)) < len(valid_tokens):
+ raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+
+ @staticmethod
+ def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
+ all_tokens = []
+ all_embeddings = []
+ for state_dict, token in zip(state_dicts, tokens):
+ if isinstance(state_dict, torch.Tensor):
+ if token is None:
+ raise ValueError(
+ "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+ )
+ loaded_token = token
+ embedding = state_dict
+ elif len(state_dict) == 1:
+ # diffusers
+ loaded_token, embedding = next(iter(state_dict.items()))
+ elif "string_to_param" in state_dict:
+ # A1111
+ loaded_token = state_dict["name"]
+ embedding = state_dict["string_to_param"]["*"]
+ else:
+ raise ValueError(
+ f"Loaded state dictionary is incorrect: {state_dict}. \n\n"
+ "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
+ " input key."
+ )
+
+ if token is not None and loaded_token != token:
+ logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+ else:
+ token = loaded_token
+
+ if token in tokenizer.get_vocab():
+ raise ValueError(
+ f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+ )
+
+ all_tokens.append(token)
+ all_embeddings.append(embedding)
+
+ return all_tokens, all_embeddings
+
+ @staticmethod
+ def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
+ all_tokens = []
+ all_embeddings = []
+
+ for embedding, token in zip(embeddings, tokens):
+ if f"{token}_1" in tokenizer.get_vocab():
+ multi_vector_tokens = [token]
+ i = 1
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+ multi_vector_tokens.append(f"{token}_{i}")
+ i += 1
+
+ raise ValueError(
+ f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+ )
+
+ is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+ if is_multi_vector:
+ all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+ all_embeddings += [e for e in embedding] # noqa: C416
+ else:
+ all_tokens += [token]
+ all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+
+ return all_tokens, all_embeddings
+
+ def load_textual_inversion(
+ self,
+ pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
+ token: Optional[Union[str, List[str]]] = None,
+ tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
+ text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
+ **kwargs,
+ ):
+ r"""
+ Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
+ Automatic1111 formats are supported).
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
+ Can be either one of the following or a list of them:
+
+ - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
+ pretrained model hosted on the Hub.
+ - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
+ inversion weights.
+ - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ token (`str` or `List[str]`, *optional*):
+ Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
+ list, then `token` must also be a list of equal length.
+ text_encoder ([`~transformers.CLIPTextModel`], *optional*):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ If not specified, the `self.text_encoder` attribute is used.
+ tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
+ A `CLIPTokenizer` to tokenize text. If not specified, the `self.tokenizer` attribute is used.
+ weight_name (`str`, *optional*):
+ Name of a custom weight file. This should be used when:
+
+ - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
+ name such as `text_inv.bin`.
+ - The saved textual inversion file is in the Automatic1111 format.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+
+ Example:
+
+ To load a textual inversion embedding vector in 🤗 Diffusers format:
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ model_id = "runwayml/stable-diffusion-v1-5"
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+ pipe.load_textual_inversion("sd-concepts-library/cat-toy")
+
+ prompt = "A backpack"
+
+ image = pipe(prompt, num_inference_steps=50).images[0]
+ image.save("cat-backpack.png")
+ ```
+
+ To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
+ (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
+ locally:
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ model_id = "runwayml/stable-diffusion-v1-5"
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+ pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
+
+ prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
+
+ image = pipe(prompt, num_inference_steps=50).images[0]
+ image.save("character.png")
+ ```
+
+ """
+ # 1. Set correct tokenizer and text encoder
+ tokenizer = tokenizer or getattr(self, "tokenizer", None)
+ text_encoder = text_encoder or getattr(self, "text_encoder", None)
+
+ # 2. Normalize inputs
+ pretrained_model_name_or_paths = (
+ [pretrained_model_name_or_path]
+ if not isinstance(pretrained_model_name_or_path, list)
+ else pretrained_model_name_or_path
+ )
+ tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token
+
+ # 3. Check inputs
+ self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
+
+ # 4. Load state dicts of textual embeddings
+ state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
+
+ # 4. Retrieve tokens and embeddings
+ tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
+
+ # 5. Extend tokens and embeddings for multi vector
+ tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
+
+ # 6. Make sure all embeddings have the correct size
+ expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
+ if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
+ raise ValueError(
+ "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
+ f"to have a last dimension of {expected_emb_dim}, but found embeddings with last dimensions "
+ f"{[emb.shape[-1] for emb in embeddings]} instead."
+ )
+
+ # 7. Now we can be sure that loading the embedding matrix works
+ # < Unsafe code:
+
+ # 7.1 Offload all hooks in case the pipeline was cpu offloaded before; make sure to offload and onload it again afterwards
+ is_model_cpu_offload = False
+ is_sequential_cpu_offload = False
+ for _, component in self.components.items():
+ if isinstance(component, nn.Module):
+ if hasattr(component, "_hf_hook"):
+ is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+ is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
+ )
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+ # 7.2 save expected device and dtype
+ device = text_encoder.device
+ dtype = text_encoder.dtype
+
+ # 7.3 Increase token embedding matrix
+ text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
+ input_embeddings = text_encoder.get_input_embeddings().weight
+
+ # 7.4 Load token and embedding
+ for token, embedding in zip(tokens, embeddings):
+ # add tokens and get ids
+ tokenizer.add_tokens(token)
+ token_id = tokenizer.convert_tokens_to_ids(token)
+ input_embeddings.data[token_id] = embedding
+ logger.info(f"Loaded textual inversion embedding for {token}.")
+
+ input_embeddings.to(dtype=dtype, device=device)
+
+ # 7.5 Offload the model again
+ if is_model_cpu_offload:
+ self.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ self.enable_sequential_cpu_offload()
+
+ # / Unsafe Code >
+
+
+class LoraLoaderMixin:
+ r"""
+ Load LoRA layers into [`UNet2DConditionModel`] and
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
+ """
+ text_encoder_name = TEXT_ENCODER_NAME
+ unet_name = UNET_NAME
+ num_fused_loras = 0
+
+ def load_lora_weights(
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+
+ See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
+ `self.unet`.
+
+ See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
+ into `self.text_encoder`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ self.load_lora_into_unet(
+ state_dict,
+ network_alphas=network_alphas,
+ unet=self.unet,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+ self.load_lora_into_text_encoder(
+ state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ lora_scale=self.lora_scale,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ @classmethod
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return the state dict for the LoRA weights and the network alphas.
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+
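+ Example (a minimal sketch; the checkpoint path is a placeholder and the call assumes a pipeline class that
+ mixes in `LoraLoaderMixin`, such as `StableDiffusionPipeline`):
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+
+ state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict("path/to/lora-checkpoint")
+ ```
+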
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # UNet and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ unet_config = kwargs.pop("unet_config", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+
+ model_file = None
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ # Let's first try to load .safetensors weights
+ if (use_safetensors and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
+ try:
+ # Here we're relaxing the loading check to enable more Inference API
+ # friendliness where sometimes, it's not at all possible to automatically
+ # determine `weight_name`.
+ if weight_name is None:
+ weight_name = cls._best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
+ )
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ except (IOError, safetensors.SafetensorError) as e:
+ if not allow_pickle:
+ raise e
+ # try loading non-safetensors weights
+ model_file = None
+ pass
+
+ if model_file is None:
+ if weight_name is None:
+ weight_name = cls._best_guess_weight_name(
+ pretrained_model_name_or_path_or_dict, file_extension=".bin"
+ )
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ state_dict = torch.load(model_file, map_location="cpu")
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ network_alphas = None
+ # TODO: replace it with a method from `state_dict_utils`
+ if all(
+ (
+ k.startswith("lora_te_")
+ or k.startswith("lora_unet_")
+ or k.startswith("lora_te1_")
+ or k.startswith("lora_te2_")
+ )
+ for k in state_dict.keys()
+ ):
+ # Map SDXL blocks correctly.
+ if unet_config is not None:
+ # use unet config to remap block numbers
+ state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+ state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
+
+ return state_dict, network_alphas
+
+ @classmethod
+ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
+ targeted_files = []
+
+ if os.path.isfile(pretrained_model_name_or_path_or_dict):
+ return
+ elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+ targeted_files = [
+ f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
+ ]
+ else:
+ files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+ targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+ if len(targeted_files) == 0:
+ return
+
+ # "scheduler" does not correspond to a LoRA checkpoint.
+ # "optimizer" does not correspond to a LoRA checkpoint
+ # only top-level checkpoints are considered and not the other ones, hence "checkpoint".
+ unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
+ targeted_files = list(
+ filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
+ )
+
+ if len(targeted_files) > 1:
+ raise ValueError(
+ f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+ )
+ weight_name = targeted_files[0]
+ return weight_name
+
+ @classmethod
+ def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
+ # 1. get all state_dict_keys
+ all_keys = list(state_dict.keys())
+ sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
+
+ # 2. check if needs remapping, if not return original dict
+ is_in_sgm_format = False
+ for key in all_keys:
+ if any(p in key for p in sgm_patterns):
+ is_in_sgm_format = True
+ break
+
+ if not is_in_sgm_format:
+ return state_dict
+
+ # 3. Else remap from SGM patterns
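+ # SGM-style checkpoints number UNet layers globally (input_blocks_N / middle_block_N / output_blocks_N);
+ # below, each global index is translated into a diffusers block id, inner block type
+ # (resnets / attentions / samplers) and layer index, and the keys are renamed accordingly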
+ new_state_dict = {}
+ inner_block_map = ["resnets", "attentions", "upsamplers"]
+
+ # Retrieves # of down, mid and up blocks
+ input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
+
+ for layer in all_keys:
+ if "text" in layer:
+ new_state_dict[layer] = state_dict.pop(layer)
+ else:
+ layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
+ if sgm_patterns[0] in layer:
+ input_block_ids.add(layer_id)
+ elif sgm_patterns[1] in layer:
+ middle_block_ids.add(layer_id)
+ elif sgm_patterns[2] in layer:
+ output_block_ids.add(layer_id)
+ else:
+ raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
+
+ input_blocks = {
+ layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
+ for layer_id in input_block_ids
+ }
+ middle_blocks = {
+ layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
+ for layer_id in middle_block_ids
+ }
+ output_blocks = {
+ layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
+ for layer_id in output_block_ids
+ }
+
+ # Rename keys accordingly
+ for i in input_block_ids:
+ block_id = (i - 1) // (unet_config.layers_per_block + 1)
+ layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
+
+ for key in input_blocks[i]:
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
+ inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
+ inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
+ new_key = delimiter.join(
+ key.split(delimiter)[: block_slice_pos - 1]
+ + [str(block_id), inner_block_key, inner_layers_in_block]
+ + key.split(delimiter)[block_slice_pos + 1 :]
+ )
+ new_state_dict[new_key] = state_dict.pop(key)
+
+ for i in middle_block_ids:
+ key_part = None
+ if i == 0:
+ key_part = [inner_block_map[0], "0"]
+ elif i == 1:
+ key_part = [inner_block_map[1], "0"]
+ elif i == 2:
+ key_part = [inner_block_map[0], "1"]
+ else:
+ raise ValueError(f"Invalid middle block id {i}.")
+
+ for key in middle_blocks[i]:
+ new_key = delimiter.join(
+ key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
+ )
+ new_state_dict[new_key] = state_dict.pop(key)
+
+ for i in output_block_ids:
+ block_id = i // (unet_config.layers_per_block + 1)
+ layer_in_block_id = i % (unet_config.layers_per_block + 1)
+
+ for key in output_blocks[i]:
+ inner_block_id = int(key.split(delimiter)[block_slice_pos])
+ inner_block_key = inner_block_map[inner_block_id]
+ inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
+ new_key = delimiter.join(
+ key.split(delimiter)[: block_slice_pos - 1]
+ + [str(block_id), inner_block_key, inner_layers_in_block]
+ + key.split(delimiter)[block_slice_pos + 1 :]
+ )
+ new_state_dict[new_key] = state_dict.pop(key)
+
+ if len(state_dict) > 0:
+ raise ValueError("At this point all state dict entries have to be converted.")
+
+ return new_state_dict
+
+ @classmethod
+ def _optionally_disable_offloading(cls, _pipeline):
+ """
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+ Args:
+ _pipeline (`DiffusionPipeline`):
+ The pipeline to disable offloading for.
+
+ Returns:
+ tuple:
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+ """
+ is_model_cpu_offload = False
+ is_sequential_cpu_offload = False
+
+ if _pipeline is not None:
+ for _, component in _pipeline.components.items():
+ if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+ if not is_model_cpu_offload:
+ is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+ if not is_sequential_cpu_offload:
+ is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
+
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+ return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+ @classmethod
+ def load_lora_into_unet(
+ cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `unet`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+ unet (`UNet2DConditionModel`):
+ The UNet model to load the LoRA layers into.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ """
+ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+
+ if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+ # Load the layers corresponding to UNet.
+ logger.info(f"Loading {cls.unet_name}.")
+
+ unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
+ state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+
+ if network_alphas is not None:
+ alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
+ network_alphas = {
+ k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ else:
+ # Otherwise, we're dealing with the old format. This means the `state_dict` should only
+ # contain the module names of the `unet` as its keys WITHOUT any prefix.
+ warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+ logger.warn(warn_message)
+
+ if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
+ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+ if adapter_name in getattr(unet, "peft_config", {}):
+ raise ValueError(
+ f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
+ )
+
+ state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+ if network_alphas is not None:
+ # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
+ # `convert_unet_state_dict_to_peft` method.
+ network_alphas = convert_unet_state_dict_to_peft(network_alphas)
+
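+ # in the PEFT layout, `lora_B` weights have shape (out_features, rank), so dim 1 gives the rank of each
+ # LoRA layer; the ranks are needed to rebuild the `LoraConfig` below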
+ rank = {}
+ for key, val in state_dict.items():
+ if "lora_B" in key:
+ rank[key] = val.shape[1]
+
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(unet)
+
+ # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+ # otherwise loading LoRA weights will lead to an error
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
+ incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Offload back.
+ if is_model_cpu_offload:
+ _pipeline.enable_model_cpu_offload()
+ elif is_sequential_cpu_offload:
+ _pipeline.enable_sequential_cpu_offload()
+ # Unsafe code />
+
+ unet.load_attn_procs(
+ state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+ )
+
+ @classmethod
+ def load_lora_into_text_encoder(
+ cls,
+ state_dict,
+ network_alphas,
+ text_encoder,
+ prefix=None,
+ lora_scale=1.0,
+ low_cpu_mem_usage=None,
+ adapter_name=None,
+ _pipeline=None,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys should be prefixed with an
+ additional `text_encoder` to distinguish them from the unet lora layers.
+ network_alphas (`Dict[str, float]`):
+ See `LoRALinearLayer` for more details.
+ text_encoder (`CLIPTextModel`):
+ The text encoder model to load the LoRA layers into.
+ prefix (`str`):
+ Expected prefix of the `text_encoder` in the `state_dict`.
+ lora_scale (`float`):
+ How much to scale the output of the lora linear layer before it is added to the output of the regular
+ linear layer.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
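+
+ Example (a minimal, illustrative sketch; `"path/to/lora"` is a placeholder, and the state dict is assumed to
+ come from `LoraLoaderMixin.lora_state_dict`):
+
+ ```python
+ from diffusers import StableDiffusionPipeline
+
+ pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ state_dict, network_alphas = pipeline.lora_state_dict("path/to/lora")
+ pipeline.load_lora_into_text_encoder(
+ state_dict, network_alphas=network_alphas, text_encoder=pipeline.text_encoder
+ )
+ ```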
+ """
+ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
+
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+ # their prefixes.
+ keys = list(state_dict.keys())
+ prefix = cls.text_encoder_name if prefix is None else prefix
+
+ # Safe prefix to check with.
+ if any(cls.text_encoder_name in key for key in keys):
+ # Load the layers corresponding to text encoder and make necessary adjustments.
+ text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+ text_encoder_lora_state_dict = {
+ k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+ }
+
+ if len(text_encoder_lora_state_dict) > 0:
+ logger.info(f"Loading {prefix}.")
+ rank = {}
+ text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+ if USE_PEFT_BACKEND:
+ # convert state dict
+ text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ rank_key = f"{name}.out_proj.lora_B.weight"
+ rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+ patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+ if patch_mlp:
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ rank_key_fc1 = f"{name}.fc1.lora_B.weight"
+ rank_key_fc2 = f"{name}.fc2.lora_B.weight"
+
+ rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
+ rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
+ else:
+ for name, _ in text_encoder_attn_modules(text_encoder):
+ rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
+ rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
+
+ patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+ if patch_mlp:
+ for name, _ in text_encoder_mlp_modules(text_encoder):
+ rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
+ rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
+ rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
+ rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
+
+ if network_alphas is not None:
+ alpha_keys = [
+ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+ ]
+ network_alphas = {
+ k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+ }
+
+ if USE_PEFT_BACKEND:
+ from peft import LoraConfig
+
+ lora_config_kwargs = get_peft_kwargs(
+ rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
+ )
+
+ lora_config = LoraConfig(**lora_config_kwargs)
+
+ # adapter_name
+ if adapter_name is None:
+ adapter_name = get_adapter_name(text_encoder)
+
+ is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+ # inject LoRA layers and load the state dict
+ # in transformers we automatically check whether the adapter name is already in use or not
+ text_encoder.load_adapter(
+ adapter_name=adapter_name,
+ adapter_state_dict=text_encoder_lora_state_dict,
+ peft_config=lora_config,
+ )
+
+ # scale LoRA layers with `lora_scale`
+ scale_lora_layers(text_encoder, weight=lora_scale)
+ else:
+ cls._modify_text_encoder(
+ text_encoder,
+ lora_scale,
+ network_alphas,
+ rank=rank,
+ patch_mlp=patch_mlp,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ is_pipeline_offloaded = _pipeline is not None and any(
+ isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
+ for c in _pipeline.components.values()
+ )
+ if is_pipeline_offloaded and not low_cpu_mem_usage:
+ low_cpu_mem_usage = True
+ logger.info(
+ f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+ )
+
+ if low_cpu_mem_usage:
+ device = next(iter(text_encoder_lora_state_dict.values())).device
+ dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
+ unexpected_keys = load_model_dict_into_meta(
+ text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
+ )
+ else:
+ load_state_dict_results = text_encoder.load_state_dict(
+ text_encoder_lora_state_dict, strict=False
+ )
+ unexpected_keys = load_state_dict_results.unexpected_keys
+
+ if len(unexpected_keys) != 0:
+ raise ValueError(
+ f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
+ )
+
+ @property
+ def lora_scale(self) -> float:
+ # property function that returns the lora scale which can be set at run time by the pipeline.
+ # if _lora_scale has not been set, return 1
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+ def _remove_text_encoder_monkey_patch(self):
+ if USE_PEFT_BACKEND:
+ remove_method = recurse_remove_peft_layers
+ else:
+ remove_method = self._remove_text_encoder_monkey_patch_classmethod
+
+ if hasattr(self, "text_encoder"):
+ remove_method(self.text_encoder)
+
+ # In case the text encoder has no LoRA attached
+ if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
+ del self.text_encoder.peft_config
+ self.text_encoder._hf_peft_config_loaded = None
+ if hasattr(self, "text_encoder_2"):
+ remove_method(self.text_encoder_2)
+ if USE_PEFT_BACKEND:
+ del self.text_encoder_2.peft_config
+ self.text_encoder_2._hf_peft_config_loaded = None
+
+ @classmethod
+ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
+ if version.parse(__version__) > version.parse("0.23"):
+ deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
+
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
+ attn_module.q_proj.lora_linear_layer = None
+ attn_module.k_proj.lora_linear_layer = None
+ attn_module.v_proj.lora_linear_layer = None
+ attn_module.out_proj.lora_linear_layer = None
+
+ for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+ if isinstance(mlp_module.fc1, PatchedLoraProjection):
+ mlp_module.fc1.lora_linear_layer = None
+ mlp_module.fc2.lora_linear_layer = None
+
+ @classmethod
+ def _modify_text_encoder(
+ cls,
+ text_encoder,
+ lora_scale=1,
+ network_alphas=None,
+ rank: Union[Dict[str, int], int] = 4,
+ dtype=None,
+ patch_mlp=False,
+ low_cpu_mem_usage=False,
+ ):
+ r"""
+ Monkey-patches the forward passes of attention modules of the text encoder.
+ """
+ if version.parse(__version__) > version.parse("0.23"):
+ deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
+
+ def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
+ linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
+ ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+ with ctx():
+ model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
+
+ lora_parameters.extend(model.lora_linear_layer.parameters())
+ return model
+
+ # First, remove any monkey-patch that might have been applied before
+ cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
+
+ lora_parameters = []
+ network_alphas = {} if network_alphas is None else network_alphas
+ is_network_alphas_populated = len(network_alphas) > 0
+
+ for name, attn_module in text_encoder_attn_modules(text_encoder):
+ query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
+ key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
+ value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
+ out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
+
+ if isinstance(rank, dict):
+ current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
+ else:
+ current_rank = rank
+
+ attn_module.q_proj = create_patched_linear_lora(
+ attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
+ )
+ attn_module.k_proj = create_patched_linear_lora(
+ attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
+ )
+ attn_module.v_proj = create_patched_linear_lora(
+ attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
+ )
+ attn_module.out_proj = create_patched_linear_lora(
+ attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
+ )
+
+ if patch_mlp:
+ for name, mlp_module in text_encoder_mlp_modules(text_encoder):
+ fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
+ fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)
+
+ current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
+ current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
+
+ mlp_module.fc1 = create_patched_linear_lora(
+ mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
+ )
+ mlp_module.fc2 = create_patched_linear_lora(
+ mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
+ )
+
+ if is_network_alphas_populated and len(network_alphas) > 0:
+ raise ValueError(
+ f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+ )
+
+ return lora_parameters
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `unet`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
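+
+ Example (a minimal, illustrative sketch; `unet_lora_state_dict` and `text_encoder_lora_state_dict` are
+ placeholders for state dicts produced by a LoRA training script):
+
+ ```python
+ pipeline.save_lora_weights(
+ save_directory="path/to/output",
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=text_encoder_lora_state_dict,
+ safe_serialization=True,
+ )
+ ```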
+ """
+ # Create a flat dictionary.
+ state_dict = {}
+
+ # Populate the dictionary.
+ if unet_lora_layers is not None:
+ weights = (
+ unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
+ )
+
+ unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
+ state_dict.update(unet_lora_state_dict)
+
+ if text_encoder_lora_layers is not None:
+ weights = (
+ text_encoder_lora_layers.state_dict()
+ if isinstance(text_encoder_lora_layers, torch.nn.Module)
+ else text_encoder_lora_layers
+ )
+
+ text_encoder_lora_state_dict = {
+ f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
+ }
+ state_dict.update(text_encoder_lora_state_dict)
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ @staticmethod
+ def write_lora_layers(
+ state_dict: Dict[str, torch.Tensor],
+ save_directory: str,
+ is_main_process: bool,
+ weight_name: str,
+ save_function: Callable,
+ safe_serialization: bool,
+ ):
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ if save_function is None:
+ if safe_serialization:
+
+ def save_function(weights, filename):
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+ else:
+ save_function = torch.save
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if weight_name is None:
+ if safe_serialization:
+ weight_name = LORA_WEIGHT_NAME_SAFE
+ else:
+ weight_name = LORA_WEIGHT_NAME
+
+ save_function(state_dict, os.path.join(save_directory, weight_name))
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+ @classmethod
+ def _convert_kohya_lora_to_diffusers(cls, state_dict):
+ unet_state_dict = {}
+ te_state_dict = {}
+ te2_state_dict = {}
+ network_alphas = {}
+
+ # every down weight has a corresponding up weight and potentially an alpha weight
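+ # For example (illustrative), a Kohya-style key such as
+ # "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight" is paired with a
+ # matching "...lora_up.weight" (and possibly an "...alpha" scalar) and is renamed below to
+ # "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"; the
+ # "unet." / "text_encoder." prefixes are prepended at the end of this method.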
+ lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
+ for key in lora_keys:
+ lora_name = key.split(".")[0]
+ lora_name_up = lora_name + ".lora_up.weight"
+ lora_name_alpha = lora_name + ".alpha"
+
+ if lora_name.startswith("lora_unet_"):
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+ if "input.blocks" in diffusers_name:
+ diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+ else:
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+
+ if "middle.block" in diffusers_name:
+ diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+ else:
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+ if "output.blocks" in diffusers_name:
+ diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+ else:
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+ diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+ diffusers_name = diffusers_name.replace("proj.out", "proj_out")
+ diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
+
+ # SDXL specificity.
+ if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
+ pattern = r"\.\d+(?=\D*$)"
+ diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
+ if ".in." in diffusers_name:
+ diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
+ if ".out." in diffusers_name:
+ diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
+ if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
+ diffusers_name = diffusers_name.replace("op", "conv")
+ if "skip" in diffusers_name:
+ diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
+
+ # LyCORIS specificity.
+ if "time.emb.proj" in diffusers_name:
+ diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
+ if "conv.shortcut" in diffusers_name:
+ diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
+
+ # General coverage.
+ if "transformer_blocks" in diffusers_name:
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif "ff" in diffusers_name:
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ else:
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ elif lora_name.startswith("lora_te_"):
+ diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ if "self_attn" in diffusers_name:
+ te_state_dict[diffusers_name] = state_dict.pop(key)
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif "mlp" in diffusers_name:
+ # Be aware that this is the new diffusers convention and the rest of the code might
+ # not utilize it yet.
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+ te_state_dict[diffusers_name] = state_dict.pop(key)
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ # (sayakpaul): Duplicate code. Needs to be cleaned.
+ elif lora_name.startswith("lora_te1_"):
+ diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ if "self_attn" in diffusers_name:
+ te_state_dict[diffusers_name] = state_dict.pop(key)
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif "mlp" in diffusers_name:
+ # Be aware that this is the new diffusers convention and the rest of the code might
+ # not utilize it yet.
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+ te_state_dict[diffusers_name] = state_dict.pop(key)
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ # (sayakpaul): Duplicate code. Needs to be cleaned.
+ elif lora_name.startswith("lora_te2_"):
+ diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+ if "self_attn" in diffusers_name:
+ te2_state_dict[diffusers_name] = state_dict.pop(key)
+ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+ elif "mlp" in diffusers_name:
+ # Be aware that this is the new diffusers convention and the rest of the code might
+ # not utilize it yet.
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+ te2_state_dict[diffusers_name] = state_dict.pop(key)
+ te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
+
+ # Rename the alphas so that they can be mapped appropriately.
+ if lora_name_alpha in state_dict:
+ alpha = state_dict.pop(lora_name_alpha).item()
+ if lora_name_alpha.startswith("lora_unet_"):
+ prefix = "unet."
+ elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
+ prefix = "text_encoder."
+ else:
+ prefix = "text_encoder_2."
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
+ network_alphas.update({new_name: alpha})
+
+ if len(state_dict) > 0:
+ raise ValueError(
+ f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}"
+ )
+
+ logger.info("Kohya-style checkpoint detected.")
+ unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
+ te_state_dict = {
+ f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()
+ }
+ te2_state_dict = (
+ {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
+ if len(te2_state_dict) > 0
+ else None
+ )
+ if te2_state_dict is not None:
+ te_state_dict.update(te2_state_dict)
+
+ new_state_dict = {**unet_state_dict, **te_state_dict}
+ return new_state_dict, network_alphas
+
+ def unload_lora_weights(self):
+ """
+ Unloads the LoRA parameters.
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+ >>> pipeline.unload_lora_weights()
+ >>> ...
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ if version.parse(__version__) > version.parse("0.23"):
+ logger.warn(
+ "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
+ "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
+ )
+
+ for _, module in self.unet.named_modules():
+ if hasattr(module, "set_lora_layer"):
+ module.set_lora_layer(None)
+ else:
+ recurse_remove_peft_layers(self.unet)
+ if hasattr(self.unet, "peft_config"):
+ del self.unet.peft_config
+
+ # Safe to call the following regardless of LoRA.
+ self._remove_text_encoder_monkey_patch()
+
+ def fuse_lora(
+ self,
+ fuse_unet: bool = True,
+ fuse_text_encoder: bool = True,
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ This is an experimental API.
+
+ Args:
+ fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters.
+ fuse_text_encoder (`bool`, defaults to `True`):
+ Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
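+
+ Example (a minimal, illustrative sketch; the LoRA repository and file name below are placeholders for any
+ compatible LoRA checkpoint):
+
+ ```python
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```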
+ """
+ if fuse_unet or fuse_text_encoder:
+ self.num_fused_loras += 1
+ if self.num_fused_loras > 1:
+ logger.warn(
+ "The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.",
+ )
+
+ if fuse_unet:
+ self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+
+ if USE_PEFT_BACKEND:
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+ # TODO(Patrick, Younes): enable "safe" fusing
+ for module in text_encoder.modules():
+ if isinstance(module, BaseTunerLayer):
+ if lora_scale != 1.0:
+ module.scale_layer(lora_scale)
+
+ module.merge()
+
+ else:
+ if version.parse(__version__) > version.parse("0.23"):
+ deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
+
+ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
+ attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
+ attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
+ attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
+ attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
+
+ for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+ if isinstance(mlp_module.fc1, PatchedLoraProjection):
+ mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
+ mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
+
+ if fuse_text_encoder:
+ if hasattr(self, "text_encoder"):
+ fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
+ if hasattr(self, "text_encoder_2"):
+ fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
+
+ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
+
+ This is an experimental API.
+
+ Args:
+ unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ unfuse_text_encoder (`bool`, defaults to `True`):
+ Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+ LoRA parameters then it won't have any effect.
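+
+ Example (illustrative sketch, continuing from a pipeline that has already loaded and fused a LoRA;
+ `pipeline` is a placeholder):
+
+ ```python
+ pipeline.fuse_lora(lora_scale=0.7)
+ # ... run inference with the fused LoRA weights ...
+ pipeline.unfuse_lora()
+ ```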
+ """
+ if unfuse_unet:
+ if not USE_PEFT_BACKEND:
+ self.unet.unfuse_lora()
+ else:
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ for module in self.unet.modules():
+ if isinstance(module, BaseTunerLayer):
+ module.unmerge()
+
+ if USE_PEFT_BACKEND:
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ def unfuse_text_encoder_lora(text_encoder):
+ for module in text_encoder.modules():
+ if isinstance(module, BaseTunerLayer):
+ module.unmerge()
+
+ else:
+ if version.parse(__version__) > version.parse("0.23"):
+ deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
+
+ def unfuse_text_encoder_lora(text_encoder):
+ for _, attn_module in text_encoder_attn_modules(text_encoder):
+ if isinstance(attn_module.q_proj, PatchedLoraProjection):
+ attn_module.q_proj._unfuse_lora()
+ attn_module.k_proj._unfuse_lora()
+ attn_module.v_proj._unfuse_lora()
+ attn_module.out_proj._unfuse_lora()
+
+ for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+ if isinstance(mlp_module.fc1, PatchedLoraProjection):
+ mlp_module.fc1._unfuse_lora()
+ mlp_module.fc2._unfuse_lora()
+
+ if unfuse_text_encoder:
+ if hasattr(self, "text_encoder"):
+ unfuse_text_encoder_lora(self.text_encoder)
+ if hasattr(self, "text_encoder_2"):
+ unfuse_text_encoder_lora(self.text_encoder_2)
+
+ self.num_fused_loras -= 1
+
+ def set_adapters_for_text_encoder(
+ self,
+ adapter_names: Union[List[str], str],
+ text_encoder: Optional[PreTrainedModel] = None,
+ text_encoder_weights: List[float] = None,
+ ):
+ """
+ Sets the adapter layers for the text encoder.
+
+ Args:
+ adapter_names (`List[str]` or `str`):
+ The names of the adapters to use.
+ text_encoder (`torch.nn.Module`, *optional*):
+ The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ text_encoder_weights (`List[float]`, *optional*):
+ The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
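+
+ Example (illustrative sketch; "toy" and "pixel" are placeholder adapter names registered by earlier
+ `load_lora_weights(..., adapter_name=...)` calls):
+
+ ```python
+ pipeline.set_adapters_for_text_encoder(["toy", "pixel"], text_encoder_weights=[0.5, 1.0])
+ ```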
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ def process_weights(adapter_names, weights):
+ if weights is None:
+ weights = [1.0] * len(adapter_names)
+ elif isinstance(weights, float):
+ weights = [weights]
+
+ if len(adapter_names) != len(weights):
+ raise ValueError(
+ f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
+ )
+ return weights
+
+ adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+ text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
+ text_encoder = text_encoder or getattr(self, "text_encoder", None)
+ if text_encoder is None:
+ raise ValueError(
+ "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
+ )
+ set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
+
+ def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+ """
+ Disables the LoRA layers for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`, *optional*):
+ The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
+ `text_encoder` attribute.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ text_encoder = text_encoder or getattr(self, "text_encoder", None)
+ if text_encoder is None:
+ raise ValueError("Text Encoder not found.")
+ set_adapter_layers(text_encoder, enabled=False)
+
+ def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+ """
+ Enables the LoRA layers for the text encoder.
+
+ Args:
+ text_encoder (`torch.nn.Module`, *optional*):
+ The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
+ attribute.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+ text_encoder = text_encoder or getattr(self, "text_encoder", None)
+ if text_encoder is None:
+ raise ValueError("Text Encoder not found.")
+ set_adapter_layers(text_encoder, enabled=True)
+
+ def set_adapters(
+ self,
+ adapter_names: Union[List[str], str],
+ adapter_weights: Optional[List[float]] = None,
+ ):
+ # Handle the UNET
+ self.unet.set_adapters(adapter_names, adapter_weights)
+
+ # Handle the Text Encoder
+ if hasattr(self, "text_encoder"):
+ self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
+ if hasattr(self, "text_encoder_2"):
+ self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
+
+ def disable_lora(self):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # Disable unet adapters
+ self.unet.disable_lora()
+
+ # Disable text encoder adapters
+ if hasattr(self, "text_encoder"):
+ self.disable_lora_for_text_encoder(self.text_encoder)
+ if hasattr(self, "text_encoder_2"):
+ self.disable_lora_for_text_encoder(self.text_encoder_2)
+
+ def enable_lora(self):
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ # Enable unet adapters
+ self.unet.enable_lora()
+
+ # Enable text encoder adapters
+ if hasattr(self, "text_encoder"):
+ self.enable_lora_for_text_encoder(self.text_encoder)
+ if hasattr(self, "text_encoder_2"):
+ self.enable_lora_for_text_encoder(self.text_encoder_2)
+
+ def get_active_adapters(self) -> List[str]:
+ """
+ Gets the list of the current active adapters.
+
+ Example:
+
+ ```python
+ from diffusers import DiffusionPipeline
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ ).to("cuda")
+ pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+ pipeline.get_active_adapters()
+ ```
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError(
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
+ )
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ active_adapters = []
+
+ for module in self.unet.modules():
+ if isinstance(module, BaseTunerLayer):
+ active_adapters = module.active_adapters
+ break
+
+ return active_adapters
+
+ def get_list_adapters(self) -> Dict[str, List[str]]:
+ """
+ Gets the current list of all available adapters in the pipeline.
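+
+ Example (illustrative; assumes two adapters named "toy" and "pixel" were loaded beforehand):
+
+ ```python
+ pipeline.get_list_adapters()
+ # e.g. {"unet": ["toy", "pixel"], "text_encoder": ["toy", "pixel"]}
+ ```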
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError(
+ "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
+ )
+
+ set_adapters = {}
+
+ if hasattr(self, "text_encoder") and hasattr(self.text_encoder, "peft_config"):
+ set_adapters["text_encoder"] = list(self.text_encoder.peft_config.keys())
+
+ if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
+ set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
+
+ if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
+ set_adapters["unet"] = list(self.unet.peft_config.keys())
+
+ return set_adapters
+
+ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
+ """
+ Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
+ you want to load multiple adapters and free some GPU memory.
+
+ Args:
+ adapter_names (`List[str]`):
+ List of adapters to send to the device.
+ device (`Union[torch.device, str, int]`):
+ Device to send the adapters to. Can be either a torch device, a str or an integer.
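+
+ Example (illustrative sketch; the adapter names are placeholders for adapters loaded earlier):
+
+ ```python
+ # park the "toy" adapter on the CPU to free GPU memory and keep "pixel" on the GPU
+ pipeline.set_lora_device(["toy"], device="cpu")
+ pipeline.set_lora_device(["pixel"], device="cuda")
+ ```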
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ # Handle the UNET
+ for unet_module in self.unet.modules():
+ if isinstance(unet_module, BaseTunerLayer):
+ for adapter_name in adapter_names:
+ unet_module.lora_A[adapter_name].to(device)
+ unet_module.lora_B[adapter_name].to(device)
+
+ # Handle the text encoder
+ modules_to_process = []
+ if hasattr(self, "text_encoder"):
+ modules_to_process.append(self.text_encoder)
+
+ if hasattr(self, "text_encoder_2"):
+ modules_to_process.append(self.text_encoder_2)
+
+ for text_encoder in modules_to_process:
+ # loop over submodules
+ for text_encoder_module in text_encoder.modules():
+ if isinstance(text_encoder_module, BaseTunerLayer):
+ for adapter_name in adapter_names:
+ text_encoder_module.lora_A[adapter_name].to(device)
+ text_encoder_module.lora_B[adapter_name].to(device)
+
+
+class FromSingleFileMixin:
+ """
+ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
+ """
+
+ @classmethod
+ def from_ckpt(cls, *args, **kwargs):
+ deprecation_message = "The function `from_ckpt` is deprecated in favor of `from_single_file` and will be removed in diffusers v.0.21. Please make sure to use `StableDiffusionPipeline.from_single_file(...)` instead."
+ deprecate("from_ckpt", "0.21.0", deprecation_message, standard_warn=False)
+ return cls.from_single_file(*args, **kwargs)
+
+ @classmethod
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
+ format. The pipeline is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+ - A link to the `.ckpt` file (for example
+ `"https://huggingface.co//blob/main/.ckpt"`) on the Hub.
+ - A path to a *file* containing all pipeline weights.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ extract_ema (`bool`, *optional*, defaults to `False`):
+ Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
+ higher quality images for inference. Non-EMA weights are usually better for continuing finetuning.
+ upcast_attention (`bool`, *optional*, defaults to `None`):
+ Whether the attention computation should always be upcasted.
+ image_size (`int`, *optional*, defaults to 512):
+ The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
+ Diffusion v2 base model. Use 768 for Stable Diffusion v2.
+ prediction_type (`str`, *optional*):
+ The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
+ the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
+ num_in_channels (`int`, *optional*, defaults to `None`):
+ The number of input channels. If `None`, it is automatically inferred.
+ scheduler_type (`str`, *optional*, defaults to `"pndm"`):
+ Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
+ "ddim"]`.
+ load_safety_checker (`bool`, *optional*, defaults to `True`):
+ Whether to load the safety checker or not.
+ text_encoder ([`~transformers.CLIPTextModel`], *optional*, defaults to `None`):
+ An instance of `CLIPTextModel` to use, specifically the
+ [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
+ parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
+ vae (`AutoencoderKL`, *optional*, defaults to `None`):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
+ this parameter is `None`, the function will load a new instance of [`AutoencoderKL`] by itself, if needed.
+ tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
+ An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
+ of `CLIPTokenizer` by itself if needed.
+ original_config_file (`str`):
+ Path to `.yaml` config file corresponding to the original architecture. If `None`, will be
+ automatically inferred by looking for a key that only exists in SD2.0 models.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (for example the pipeline components of the
+ specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
+ method. See example below for more information.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
+ ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+ ... )
+
+ >>> # Download pipeline from local file
+ >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
+ >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly")
+
+ >>> # Enable float16 and move to GPU
+ >>> pipeline = StableDiffusionPipeline.from_single_file(
+ ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> pipeline.to("cuda")
+ ```
+ """
+ # import here to avoid circular dependency
+ from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
+
+ original_config_file = kwargs.pop("original_config_file", None)
+ config_files = kwargs.pop("config_files", None)
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ extract_ema = kwargs.pop("extract_ema", False)
+ image_size = kwargs.pop("image_size", None)
+ scheduler_type = kwargs.pop("scheduler_type", "pndm")
+ num_in_channels = kwargs.pop("num_in_channels", None)
+ upcast_attention = kwargs.pop("upcast_attention", None)
+ load_safety_checker = kwargs.pop("load_safety_checker", True)
+ prediction_type = kwargs.pop("prediction_type", None)
+ text_encoder = kwargs.pop("text_encoder", None)
+ vae = kwargs.pop("vae", None)
+ controlnet = kwargs.pop("controlnet", None)
+ tokenizer = kwargs.pop("tokenizer", None)
+
+ torch_dtype = kwargs.pop("torch_dtype", None)
+
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ pipeline_name = cls.__name__
+ file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+ from_safetensors = file_extension == "safetensors"
+
+ if from_safetensors and use_safetensors is False:
+ raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+ # TODO: For now we only support stable diffusion
+ stable_unclip = None
+ model_type = None
+
+ if pipeline_name in [
+ "StableDiffusionControlNetPipeline",
+ "StableDiffusionControlNetImg2ImgPipeline",
+ "StableDiffusionControlNetInpaintPipeline",
+ ]:
+ from .models.controlnet import ControlNetModel
+ from .pipelines.controlnet.multicontrolnet import MultiControlNetModel
+
+ # list/tuple or a single instance of ControlNetModel or MultiControlNetModel
+ if not (
+ isinstance(controlnet, (ControlNetModel, MultiControlNetModel))
+ or isinstance(controlnet, (list, tuple))
+ and isinstance(controlnet[0], ControlNetModel)
+ ):
+ raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
+ elif "StableDiffusion" in pipeline_name:
+ # Model type will be inferred from the checkpoint.
+ pass
+ elif pipeline_name == "StableUnCLIPPipeline":
+ model_type = "FrozenOpenCLIPEmbedder"
+ stable_unclip = "txt2img"
+ elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
+ model_type = "FrozenOpenCLIPEmbedder"
+ stable_unclip = "img2img"
+ elif pipeline_name == "PaintByExamplePipeline":
+ model_type = "PaintByExample"
+ elif pipeline_name == "LDMTextToImagePipeline":
+ model_type = "LDMTextToImage"
+ else:
+ raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
+
+ # remove huggingface url
+ has_valid_url_prefix = False
+ valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
+ for prefix in valid_url_prefixes:
+ if pretrained_model_link_or_path.startswith(prefix):
+ pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+ has_valid_url_prefix = True
+
+ # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+ ckpt_path = Path(pretrained_model_link_or_path)
+ if not ckpt_path.is_file():
+ if not has_valid_url_prefix:
+ raise ValueError(
+ f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
+ )
+
+ # get repo_id and (potentially nested) file path of ckpt in repo
+ repo_id = "/".join(ckpt_path.parts[:2])
+ file_path = "/".join(ckpt_path.parts[2:])
+
+ if file_path.startswith("blob/"):
+ file_path = file_path[len("blob/") :]
+
+ if file_path.startswith("main/"):
+ file_path = file_path[len("main/") :]
+
+ pretrained_model_link_or_path = hf_hub_download(
+ repo_id,
+ filename=file_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ )
+
+ pipe = download_from_original_stable_diffusion_ckpt(
+ pretrained_model_link_or_path,
+ pipeline_class=cls,
+ model_type=model_type,
+ stable_unclip=stable_unclip,
+ controlnet=controlnet,
+ from_safetensors=from_safetensors,
+ extract_ema=extract_ema,
+ image_size=image_size,
+ scheduler_type=scheduler_type,
+ num_in_channels=num_in_channels,
+ upcast_attention=upcast_attention,
+ load_safety_checker=load_safety_checker,
+ prediction_type=prediction_type,
+ text_encoder=text_encoder,
+ vae=vae,
+ tokenizer=tokenizer,
+ original_config_file=original_config_file,
+ config_files=config_files,
+ )
+
+ if torch_dtype is not None:
+ pipe.to(torch_dtype=torch_dtype)
+
+ return pipe
+
+
+class FromOriginalVAEMixin:
+ @classmethod
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Instantiate an [`AutoencoderKL`] from pretrained VAE weights saved in the original `.ckpt` or
+ `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by
+ default.
+
+ Parameters:
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+ - A link to the `.ckpt` file (for example
+ `"https://huggingface.co//blob/main/.ckpt"`) on the Hub.
+ - A path to a *file* containing all pipeline weights.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to True, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ image_size (`int`, *optional*, defaults to 512):
+ The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
+ Diffusion v2 base model. Use 768 for Stable Diffusion v2.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ upcast_attention (`bool`, *optional*, defaults to `None`):
+ Whether the attention computation should always be upcasted.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
+ = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
+ Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (for example the pipeline components of the
+ specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
+ method. See example below for more information.
+
+ Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you want to load
+ a VAE that accompanies a Stable Diffusion v2 (or higher) or SDXL model.
+
+ Examples:
+
+ ```py
+ from diffusers import AutoencoderKL
+
+ url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
+ model = AutoencoderKL.from_single_file(url)
+ ```
+ """
+ if not is_omegaconf_available():
+ raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
+
+ from omegaconf import OmegaConf
+
+ from .models import AutoencoderKL
+
+ # import here to avoid circular dependency
+ from .pipelines.stable_diffusion.convert_from_ckpt import (
+ convert_ldm_vae_checkpoint,
+ create_vae_diffusers_config,
+ )
+
+ config_file = kwargs.pop("config_file", None)
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ image_size = kwargs.pop("image_size", None)
+ scaling_factor = kwargs.pop("scaling_factor", None)
+ kwargs.pop("upcast_attention", None)
+
+ torch_dtype = kwargs.pop("torch_dtype", None)
+
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+ from_safetensors = file_extension == "safetensors"
+
+ if from_safetensors and use_safetensors is False:
+ raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+ # remove huggingface url
+ for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+ if pretrained_model_link_or_path.startswith(prefix):
+ pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+
+ # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+ ckpt_path = Path(pretrained_model_link_or_path)
+ if not ckpt_path.is_file():
+ # get repo_id and (potentially nested) file path of ckpt in repo
+ repo_id = "/".join(ckpt_path.parts[:2])
+ file_path = "/".join(ckpt_path.parts[2:])
+
+ if file_path.startswith("blob/"):
+ file_path = file_path[len("blob/") :]
+
+ if file_path.startswith("main/"):
+ file_path = file_path[len("main/") :]
+
+ pretrained_model_link_or_path = hf_hub_download(
+ repo_id,
+ filename=file_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ )
+
+ if from_safetensors:
+ from safetensors import safe_open
+
+ checkpoint = {}
+ with safe_open(pretrained_model_link_or_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ checkpoint[key] = f.get_tensor(key)
+ else:
+ checkpoint = torch.load(pretrained_model_link_or_path, map_location="cpu")
+
+ if "state_dict" in checkpoint:
+ checkpoint = checkpoint["state_dict"]
+
+ if config_file is None:
+ config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+ config_file = BytesIO(requests.get(config_url).content)
+
+ original_config = OmegaConf.load(config_file)
+
+ # default to sd-v1-5
+ image_size = image_size or 512
+
+ vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+
+ if scaling_factor is None:
+ if (
+ "model" in original_config
+ and "params" in original_config.model
+ and "scale_factor" in original_config.model.params
+ ):
+ vae_scaling_factor = original_config.model.params.scale_factor
+ else:
+ vae_scaling_factor = 0.18215 # default SD scaling factor
+
+ vae_config["scaling_factor"] = vae_scaling_factor
+
+ ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ with ctx():
+ vae = AutoencoderKL(**vae_config)
+
+ if is_accelerate_available():
+ load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
+ else:
+ vae.load_state_dict(converted_vae_checkpoint)
+
+ if torch_dtype is not None:
+ vae.to(dtype=torch_dtype)
+
+ return vae
+
+
+class FromOriginalControlnetMixin:
+ @classmethod
+ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+ r"""
+ Instantiate a [`ControlNetModel`] from pretrained controlnet weights saved in the original `.ckpt` or
+ `.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
+
+ Parameters:
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+ - A link to the `.ckpt` file (for example
+ `"https://huggingface.co//blob/main/.ckpt"`) on the Hub.
+ - A path to a *file* containing all pipeline weights.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+ dtype is automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+ incompletely downloaded files are deleted.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to True, the model
+ won't be downloaded from the Hub.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
+ weights. If set to `False`, safetensors weights are not loaded.
+ image_size (`int`, *optional*, defaults to 512):
+ The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
+ Diffusion v2 base model. Use 768 for Stable Diffusion v2.
+ upcast_attention (`bool`, *optional*, defaults to `None`):
+ Whether the attention computation should always be upcasted.
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite load and saveable variables (for example the pipeline components of the
+ specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
+ method. See example below for more information.
+
+ Examples:
+
+ ```py
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+
+ url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
+ controlnet = ControlNetModel.from_single_file(url)
+
+ url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
+ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
+ ```
+ """
+ # import here to avoid circular dependency
+ from .pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
+
+ config_file = kwargs.pop("config_file", None)
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ num_in_channels = kwargs.pop("num_in_channels", None)
+ use_linear_projection = kwargs.pop("use_linear_projection", None)
+ revision = kwargs.pop("revision", None)
+ extract_ema = kwargs.pop("extract_ema", False)
+ image_size = kwargs.pop("image_size", None)
+ upcast_attention = kwargs.pop("upcast_attention", None)
+
+ torch_dtype = kwargs.pop("torch_dtype", None)
+
+ use_safetensors = kwargs.pop("use_safetensors", None)
+
+ file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
+ from_safetensors = file_extension == "safetensors"
+
+ if from_safetensors and use_safetensors is False:
+ raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
+
+ # remove huggingface url
+ for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
+ if pretrained_model_link_or_path.startswith(prefix):
+ pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
+
+ # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
+ ckpt_path = Path(pretrained_model_link_or_path)
+ if not ckpt_path.is_file():
+ # get repo_id and (potentially nested) file path of ckpt in repo
+ repo_id = "/".join(ckpt_path.parts[:2])
+ file_path = "/".join(ckpt_path.parts[2:])
+
+ if file_path.startswith("blob/"):
+ file_path = file_path[len("blob/") :]
+
+ if file_path.startswith("main/"):
+ file_path = file_path[len("main/") :]
+
+ pretrained_model_link_or_path = hf_hub_download(
+ repo_id,
+ filename=file_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ )
+
+ if config_file is None:
+ config_url = "https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml"
+ config_file = BytesIO(requests.get(config_url).content)
+
+ image_size = image_size or 512
+
+ controlnet = download_controlnet_from_original_ckpt(
+ pretrained_model_link_or_path,
+ original_config_file=config_file,
+ image_size=image_size,
+ extract_ema=extract_ema,
+ num_in_channels=num_in_channels,
+ upcast_attention=upcast_attention,
+ from_safetensors=from_safetensors,
+ use_linear_projection=use_linear_projection,
+ )
+
+ if torch_dtype is not None:
+ controlnet.to(torch_dtype=torch_dtype)
+
+ return controlnet
+
+
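The least obvious part of `from_single_file` above is how a Hub blob URL is turned into a `repo_id`/`filename` pair before being handed to `hf_hub_download`. The following is a minimal, self-contained sketch of that resolution step; the helper name `resolve_single_file_link` is hypothetical and not part of diffusers.

```py
from pathlib import Path

def resolve_single_file_link(link: str):
    # Mirror of the parsing logic in `from_single_file`: strip the Hub host,
    # take the first two path parts as the repo id, and drop "blob/"/"main/"
    # from the remaining file path.
    for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
        if link.startswith(prefix):
            link = link[len(prefix):]
    parts = Path(link).parts
    repo_id = "/".join(parts[:2])      # e.g. "lllyasviel/ControlNet-v1-1"
    file_path = "/".join(parts[2:])    # e.g. "blob/main/control_v11p_sd15_canny.pth"
    for marker in ("blob/", "main/"):
        if file_path.startswith(marker):
            file_path = file_path[len(marker):]
    return repo_id, file_path

print(resolve_single_file_link(
    "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
))
# ('lllyasviel/ControlNet-v1-1', 'control_v11p_sd15_canny.pth')
```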
+class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
+ """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
+
+ # Override to properly handle the loading and unloading of the additional text encoder.
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ `self.text_encoder`.
+
+ All kwargs are forwarded to `self.lora_state_dict`.
+
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+
+ See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
+ `self.unet`.
+
+ See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
+ into `self.text_encoder`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+ """
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_unet(
+ state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+ )
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ adapter_name=adapter_name,
+ _pipeline=self,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ self,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the UNet and text encoder.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `unet`.
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
+ encoder LoRA state dict because it comes from 🤗 Transformers.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training when you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ """
+ state_dict = {}
+
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+ raise ValueError(
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+ )
+
+ if unet_lora_layers:
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ self.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def _remove_text_encoder_monkey_patch(self):
+ if USE_PEFT_BACKEND:
+ recurse_remove_peft_layers(self.text_encoder)
+ # TODO: @younesbelkada handle this in transformers side
+ if getattr(self.text_encoder, "peft_config", None) is not None:
+ del self.text_encoder.peft_config
+ self.text_encoder._hf_peft_config_loaded = None
+
+ recurse_remove_peft_layers(self.text_encoder_2)
+ if getattr(self.text_encoder_2, "peft_config", None) is not None:
+ del self.text_encoder_2.peft_config
+ self.text_encoder_2._hf_peft_config_loaded = None
+ else:
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
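For context, the SDXL LoRA mixin above is typically exercised through an SDXL pipeline roughly as follows. This is a hedged sketch, not part of the diff: the LoRA repository id is a placeholder, and the `save_lora_weights` call is shown only in comment form.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# `load_lora_weights` splits the state dict into unet / text_encoder / text_encoder_2 parts
# and loads each into the corresponding module. "some-user/some-sdxl-lora" is a placeholder id.
pipe.load_lora_weights("some-user/some-sdxl-lora", adapter_name="style")

# `save_lora_weights` packs the layer dicts back under "unet", "text_encoder", and
# "text_encoder_2" prefixes before writing them out, e.g.:
# pipe.save_lora_weights("my-lora-dir", unet_lora_layers=..., text_encoder_lora_layers=..., text_encoder_2_lora_layers=...)
```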
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 839045001bb0..a5d0066d5c40 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -14,12 +14,7 @@
from typing import TYPE_CHECKING
-from ..utils import (
- DIFFUSERS_SLOW_IMPORT,
- _LazyModule,
- is_flax_available,
- is_torch_available,
-)
+from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
_import_structure = {}
@@ -28,9 +23,7 @@
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
- _import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
- _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["modeling_utils"] = ["ModelMixin"]
@@ -42,9 +35,6 @@
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
- _import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
- _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
- _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
@@ -58,9 +48,7 @@
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
- from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
- from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
from .modeling_utils import ModelMixin
@@ -72,9 +60,6 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
- from .unet_kandi3 import Kandinsky3UNet
- from .unet_motion_model import MotionAdapter, UNetMotionModel
- from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
if is_flax_available():
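The `_import_structure` entries above only matter for lazy loading; from the user's side the public import path is unchanged, and removed entries (for example `AutoencoderKLTemporalDecoder` or `Kandinsky3UNet`) simply stop being importable at this commit. A small sketch, assuming diffusers is installed:

```py
# The lazy `_import_structure` above defers submodule imports until first access.
import diffusers.models

model_cls = diffusers.models.ControlNetModel        # triggers the lazy submodule import
from diffusers.models import UNet2DConditionModel   # equivalent eager-looking form

print(model_cls.__name__, UNet2DConditionModel.__name__)
```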
diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index 8b75162ba597..46da899096c2 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -1,34 +1,5 @@
-# coding=utf-8
-# Copyright 2023 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn.functional as F
from torch import nn
-from ..utils import USE_PEFT_BACKEND
-from .lora import LoRACompatibleLinear
-
-
-ACTIVATION_FUNCTIONS = {
- "swish": nn.SiLU(),
- "silu": nn.SiLU(),
- "mish": nn.Mish(),
- "gelu": nn.GELU(),
- "relu": nn.ReLU(),
-}
-
def get_activation(act_fn: str) -> nn.Module:
"""Helper function to get activation function from string.
@@ -39,82 +10,13 @@ def get_activation(act_fn: str) -> nn.Module:
Returns:
nn.Module: Activation function.
"""
-
- act_fn = act_fn.lower()
- if act_fn in ACTIVATION_FUNCTIONS:
- return ACTIVATION_FUNCTIONS[act_fn]
+ if act_fn in ["swish", "silu"]:
+ return nn.SiLU()
+ elif act_fn == "mish":
+ return nn.Mish()
+ elif act_fn == "gelu":
+ return nn.GELU()
+ elif act_fn == "relu":
+ return nn.ReLU()
else:
raise ValueError(f"Unsupported activation function: {act_fn}")
-
-
-class GELU(nn.Module):
- r"""
- GELU activation function with tanh approximation support with `approximate="tanh"`.
-
- Parameters:
- dim_in (`int`): The number of channels in the input.
- dim_out (`int`): The number of channels in the output.
- approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
- """
-
- def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out)
- self.approximate = approximate
-
- def gelu(self, gate: torch.Tensor) -> torch.Tensor:
- if gate.device.type != "mps":
- return F.gelu(gate, approximate=self.approximate)
- # mps: gelu is not implemented for float16
- return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
-
- def forward(self, hidden_states):
- hidden_states = self.proj(hidden_states)
- hidden_states = self.gelu(hidden_states)
- return hidden_states
-
-
-class GEGLU(nn.Module):
- r"""
- A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
-
- Parameters:
- dim_in (`int`): The number of channels in the input.
- dim_out (`int`): The number of channels in the output.
- """
-
- def __init__(self, dim_in: int, dim_out: int):
- super().__init__()
- linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
-
- self.proj = linear_cls(dim_in, dim_out * 2)
-
- def gelu(self, gate: torch.Tensor) -> torch.Tensor:
- if gate.device.type != "mps":
- return F.gelu(gate)
- # mps: gelu is not implemented for float16
- return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
-
- def forward(self, hidden_states, scale: float = 1.0):
- args = () if USE_PEFT_BACKEND else (scale,)
- hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
- return hidden_states * self.gelu(gate)
-
-
-class ApproximateGELU(nn.Module):
- r"""
- The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
- [paper](https://arxiv.org/abs/1606.08415).
-
- Parameters:
- dim_in (`int`): The number of channels in the input.
- dim_out (`int`): The number of channels in the output.
- """
-
- def __init__(self, dim_in: int, dim_out: int):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- x = self.proj(x)
- return x * torch.sigmoid(1.702 * x)
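The reverted `get_activation` helper above maps string names to activation modules and raises on anything it does not recognize. A quick usage sketch, assuming the module path `diffusers.models.activations` as shown in the diff:

```py
import torch
from diffusers.models.activations import get_activation

act = get_activation("silu")            # nn.SiLU(); "swish" maps to the same module
print(act(torch.tensor([-1.0, 0.0, 1.0])))

# Unknown names raise a ValueError rather than silently falling back:
try:
    get_activation("leaky_relu")
except ValueError as err:
    print(err)
```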
diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py
index 0f4b2ec03371..64d64d07bf77 100644
--- a/src/diffusers/models/adapter.py
+++ b/src/diffusers/models/adapter.py
@@ -20,6 +20,7 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .modeling_utils import ModelMixin
+from .resnet import Downsample2D
logger = logging.get_logger(__name__)
@@ -50,28 +51,24 @@ def __init__(self, adapters: List["T2IAdapter"]):
if len(adapters) == 1:
raise ValueError("For a single adapter, please use the `T2IAdapter` class instead of `MultiAdapter`")
- # The outputs from each adapter are added together with a weight.
- # This means that the change in dimensions from downsampling must
- # be the same for all adapters. Inductively, it also means the
- # downscale_factor and total_downscale_factor must be the same for all
- # adapters.
+ # The outputs from each adapter are added together with a weight.
+ # This means that the change in dimensions from downsampling must
+ # be the same for all adapters. Inductively, it also means the total
+ # downscale factor must be the same for all adapters.
+
first_adapter_total_downscale_factor = adapters[0].total_downscale_factor
- first_adapter_downscale_factor = adapters[0].downscale_factor
+
for idx in range(1, len(adapters)):
- if (
- adapters[idx].total_downscale_factor != first_adapter_total_downscale_factor
- or adapters[idx].downscale_factor != first_adapter_downscale_factor
- ):
+ adapter_idx_total_downscale_factor = adapters[idx].total_downscale_factor
+
+ if adapter_idx_total_downscale_factor != first_adapter_total_downscale_factor:
raise ValueError(
- f"Expecting all adapters to have the same downscaling behavior, but got:\n"
- f"adapters[0].total_downscale_factor={first_adapter_total_downscale_factor}\n"
- f"adapters[0].downscale_factor={first_adapter_downscale_factor}\n"
- f"adapter[`{idx}`].total_downscale_factor={adapters[idx].total_downscale_factor}\n"
- f"adapter[`{idx}`].downscale_factor={adapters[idx].downscale_factor}"
+ f"Expecting all adapters to have the same total_downscale_factor, "
+ f"but got adapters[0].total_downscale_factor={first_adapter_total_downscale_factor} and "
+ f"adapter[`{idx}`]={adapter_idx_total_downscale_factor}"
)
- self.total_downscale_factor = first_adapter_total_downscale_factor
- self.downscale_factor = first_adapter_downscale_factor
+ self.total_downscale_factor = adapters[0].total_downscale_factor
def forward(self, xs: torch.Tensor, adapter_weights: Optional[List[float]] = None) -> List[torch.Tensor]:
r"""
@@ -277,13 +274,6 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
def total_downscale_factor(self):
return self.adapter.total_downscale_factor
- @property
- def downscale_factor(self):
- """The downscale factor applied in the T2I-Adapter's initial pixel unshuffle operation. If an input image's dimensions are
- not evenly divisible by the downscale_factor then an exception will be raised.
- """
- return self.adapter.unshuffle.downscale_factor
-
# full adapter
@@ -409,7 +399,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow
self.downsample = None
if down:
- self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
+ self.downsample = Downsample2D(in_channels)
self.in_conv = None
if in_channels != out_channels:
@@ -456,8 +446,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
layer on the input tensor. It returns addition with the input tensor.
"""
-
- h = self.act(self.block1(x))
+ h = x
+ h = self.block1(h)
+ h = self.act(h)
h = self.block2(h)
return h + x
@@ -535,7 +526,7 @@ def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, dow
self.downsample = None
if down:
- self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True)
+ self.downsample = Downsample2D(in_channels)
self.in_conv = nn.Conv2d(in_channels, mid_channels, kernel_size=1)
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
@@ -577,8 +568,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
another convolutional layer and adds it to input tensor.
"""
-
- h = self.act(self.block1(x))
+ h = x
+ h = self.block1(h)
+ h = self.act(h)
h = self.block2(h)
return h + x
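The simplified check in `MultiAdapter` above only enforces a shared `total_downscale_factor`, since that is what has to match for the per-adapter feature maps to be summed. A pure-Python sketch of that invariant (the classes here are stand-ins, not diffusers types):

```py
# Stand-in adapter exposing only the attribute the check above cares about.
class FakeAdapter:
    def __init__(self, total_downscale_factor: int):
        self.total_downscale_factor = total_downscale_factor

def check_adapters(adapters):
    first = adapters[0].total_downscale_factor
    for idx, adapter in enumerate(adapters[1:], start=1):
        if adapter.total_downscale_factor != first:
            raise ValueError(
                f"Expecting all adapters to have the same total_downscale_factor, "
                f"but got adapters[0].total_downscale_factor={first} and "
                f"adapter[`{idx}`]={adapter.total_downscale_factor}"
            )
    # A consequence of the shared factor: input H and W must be divisible by it.
    return first

print(check_adapters([FakeAdapter(8), FakeAdapter(8)]))  # 8
```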
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index f02b5e249eee..47608005d374 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -11,43 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
import torch
+import torch.nn.functional as F
from torch import nn
from ..utils import USE_PEFT_BACKEND
from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU
+from .activations import get_activation
from .attention_processor import Attention
-from .embeddings import SinusoidalPositionalEmbedding
+from .embeddings import CombinedTimestepLabelEmbeddings
from .lora import LoRACompatibleLinear
-from .normalization import AdaLayerNorm, AdaLayerNormZero
-
-
-def _chunked_feed_forward(
- ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
-):
- # "feed_forward_chunk_size" can be used to save memory
- if hidden_states.shape[chunk_dim] % chunk_size != 0:
- raise ValueError(
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
- )
-
- num_chunks = hidden_states.shape[chunk_dim] // chunk_size
- if lora_scale is None:
- ff_output = torch.cat(
- [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
- dim=chunk_dim,
- )
- else:
- # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
- ff_output = torch.cat(
- [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
- dim=chunk_dim,
- )
-
- return ff_output
@maybe_allow_in_graph
@@ -122,10 +97,6 @@ class BasicTransformerBlock(nn.Module):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
- positional_embeddings (`str`, *optional*, defaults to `None`):
- The type of positional embeddings to apply to.
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
- The maximum number of positional embeddings to apply.
"""
def __init__(
@@ -142,20 +113,15 @@ def __init__(
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
- norm_eps: float = 1e-5,
+ norm_type: str = "layer_norm",
final_dropout: bool = False,
attention_type: str = "default",
- positional_embeddings: Optional[str] = None,
- num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
- self.use_layer_norm = norm_type == "layer_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
@@ -163,16 +129,6 @@ def __init__(
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
- if positional_embeddings and (num_positional_embeddings is None):
- raise ValueError(
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
- )
-
- if positional_embeddings == "sinusoidal":
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
- else:
- self.pos_embed = None
-
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
@@ -180,8 +136,7 @@ def __init__(
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
-
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
@@ -200,7 +155,7 @@ def __init__(
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
)
self.attn2 = Attention(
query_dim=dim,
@@ -216,29 +171,18 @@ def __init__(
self.attn2 = None
# 3. Feed-forward
- if not self.use_ada_layer_norm_single:
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
-
- self.ff = FeedForward(
- dim,
- dropout=dropout,
- activation_fn=activation_fn,
- final_dropout=final_dropout,
- )
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
- # 5. Scale-shift for PixArt-Alpha.
- if self.use_ada_layer_norm_single:
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
-
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
@@ -255,28 +199,14 @@ def forward(
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
- batch_size = hidden_states.shape[0]
-
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
- elif self.use_layer_norm:
- norm_hidden_states = self.norm1(hidden_states)
- elif self.use_ada_layer_norm_single:
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
- ).chunk(6, dim=1)
- norm_hidden_states = self.norm1(hidden_states)
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
- norm_hidden_states = norm_hidden_states.squeeze(1)
else:
- raise ValueError("Incorrect norm used")
-
- if self.pos_embed is not None:
- norm_hidden_states = self.pos_embed(norm_hidden_states)
+ norm_hidden_states = self.norm1(hidden_states)
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
@@ -293,32 +223,18 @@ def forward(
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
- elif self.use_ada_layer_norm_single:
- attn_output = gate_msa * attn_output
-
hidden_states = attn_output + hidden_states
- if hidden_states.ndim == 4:
- hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+ # 2.5 ends
# 3. Cross-Attention
if self.attn2 is not None:
- if self.use_ada_layer_norm:
- norm_hidden_states = self.norm2(hidden_states, timestep)
- elif self.use_ada_layer_norm_zero or self.use_layer_norm:
- norm_hidden_states = self.norm2(hidden_states)
- elif self.use_ada_layer_norm_single:
- # For PixArt norm2 isn't applied here:
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
- norm_hidden_states = hidden_states
- else:
- raise ValueError("Incorrect norm")
-
- if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
- norm_hidden_states = self.pos_embed(norm_hidden_states)
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
attn_output = self.attn2(
norm_hidden_states,
@@ -329,163 +245,33 @@ def forward(
hidden_states = attn_output + hidden_states
# 4. Feed-forward
- if not self.use_ada_layer_norm_single:
- norm_hidden_states = self.norm3(hidden_states)
+ norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
- if self.use_ada_layer_norm_single:
- norm_hidden_states = self.norm2(hidden_states)
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
-
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
- ff_output = _chunked_feed_forward(
- self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+ ff_output = torch.cat(
+ [
+ self.ff(hid_slice, scale=lora_scale)
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+ ],
+ dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
- elif self.use_ada_layer_norm_single:
- ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
- if hidden_states.ndim == 4:
- hidden_states = hidden_states.squeeze(1)
-
- return hidden_states
-
-
-@maybe_allow_in_graph
-class TemporalBasicTransformerBlock(nn.Module):
- r"""
- A basic Transformer block for video like data.
-
- Parameters:
- dim (`int`): The number of channels in the input and output.
- time_mix_inner_dim (`int`): The number of channels for temporal attention.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
- """
-
- def __init__(
- self,
- dim: int,
- time_mix_inner_dim: int,
- num_attention_heads: int,
- attention_head_dim: int,
- cross_attention_dim: Optional[int] = None,
- ):
- super().__init__()
- self.is_res = dim == time_mix_inner_dim
-
- self.norm_in = nn.LayerNorm(dim)
-
- # Define 3 blocks. Each block has its own normalization layer.
- # 1. Self-Attn
- self.norm_in = nn.LayerNorm(dim)
- self.ff_in = FeedForward(
- dim,
- dim_out=time_mix_inner_dim,
- activation_fn="geglu",
- )
-
- self.norm1 = nn.LayerNorm(time_mix_inner_dim)
- self.attn1 = Attention(
- query_dim=time_mix_inner_dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- cross_attention_dim=None,
- )
-
- # 2. Cross-Attn
- if cross_attention_dim is not None:
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
- # the second cross attention block.
- self.norm2 = nn.LayerNorm(time_mix_inner_dim)
- self.attn2 = Attention(
- query_dim=time_mix_inner_dim,
- cross_attention_dim=cross_attention_dim,
- heads=num_attention_heads,
- dim_head=attention_head_dim,
- ) # is self-attn if encoder_hidden_states is none
- else:
- self.norm2 = None
- self.attn2 = None
-
- # 3. Feed-forward
- self.norm3 = nn.LayerNorm(time_mix_inner_dim)
- self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
-
- # let chunk size default to None
- self._chunk_size = None
- self._chunk_dim = None
-
- def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
- # Sets chunk feed-forward
- self._chunk_size = chunk_size
- # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
- self._chunk_dim = 1
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- num_frames: int,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
- # Notice that normalization is always applied before the real computation in the following blocks.
- # 0. Self-Attention
- batch_size = hidden_states.shape[0]
-
- batch_frames, seq_length, channels = hidden_states.shape
- batch_size = batch_frames // num_frames
-
- hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
- hidden_states = hidden_states.permute(0, 2, 1, 3)
- hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
-
- residual = hidden_states
- hidden_states = self.norm_in(hidden_states)
-
- if self._chunk_size is not None:
- hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
- else:
- hidden_states = self.ff_in(hidden_states)
-
- if self.is_res:
- hidden_states = hidden_states + residual
-
- norm_hidden_states = self.norm1(hidden_states)
- attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
- hidden_states = attn_output + hidden_states
-
- # 3. Cross-Attention
- if self.attn2 is not None:
- norm_hidden_states = self.norm2(hidden_states)
- attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
- hidden_states = attn_output + hidden_states
-
- # 4. Feed-forward
- norm_hidden_states = self.norm3(hidden_states)
-
- if self._chunk_size is not None:
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
- else:
- ff_output = self.ff(norm_hidden_states)
-
- if self.is_res:
- hidden_states = ff_output + hidden_states
- else:
- hidden_states = ff_output
-
- hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
- hidden_states = hidden_states.permute(0, 2, 1, 3)
- hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
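The chunked feed-forward logic inlined above trades peak memory for an extra concatenation: the sequence is split along the chunk dimension, the feed-forward runs per chunk, and the pieces are concatenated back, which is numerically equivalent because the feed-forward has no cross-token interaction. A standalone sketch (`MiniFF` is a stand-in module, not diffusers code):

```py
import torch
from torch import nn

class MiniFF(nn.Module):
    # Stand-in feed-forward acting only on the last (feature) dimension.
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return self.net(x)

def chunked_ff(ff, hidden_states, chunk_dim, chunk_size):
    # Same pattern as above: split, apply, concatenate back.
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError("sequence length must be divisible by chunk_size")
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    return torch.cat(
        [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)], dim=chunk_dim
    )

ff = MiniFF(64)
x = torch.randn(2, 128, 64)
assert torch.allclose(chunked_ff(ff, x, chunk_dim=1, chunk_size=32), ff(x), atol=1e-6)
```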
@@ -545,3 +331,168 @@ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tens
else:
hidden_states = module(hidden_states)
return hidden_states
+
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate, approximate=self.approximate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+ self.proj = linear_cls(dim_in, dim_out * 2)
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states, scale: float = 1.0):
+ args = () if USE_PEFT_BACKEND else (scale,)
+ hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ r"""
+ The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
+ https://arxiv.org/abs/1606.08415.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
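The GEGLU variant re-added above projects to twice the output width, splits the result, and gates one half with the GELU of the other. A minimal functional sketch of the same computation:

```py
import torch
import torch.nn.functional as F
from torch import nn

def geglu(x: torch.Tensor, proj: nn.Linear) -> torch.Tensor:
    # Project to 2 * dim_out, split along the feature dim, gate one half with GELU(other half).
    hidden, gate = proj(x).chunk(2, dim=-1)
    return hidden * F.gelu(gate)

proj = nn.Linear(64, 2 * 128)
out = geglu(torch.randn(4, 10, 64), proj)
print(out.shape)  # torch.Size([4, 10, 128])
```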
+class AdaLayerNorm(nn.Module):
+ r"""
+ Norm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the dictionary of embeddings.
+ """
+
+ def __init__(self, embedding_dim: int, num_embeddings: int):
+ super().__init__()
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
+ emb = self.linear(self.silu(self.emb(timestep)))
+ scale, shift = torch.chunk(emb, 2)
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+
+class AdaLayerNormZero(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the dictionary of embeddings.
+ """
+
+ def __init__(self, embedding_dim: int, num_embeddings: int):
+ super().__init__()
+
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timestep: torch.Tensor,
+ class_labels: torch.LongTensor,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class AdaGroupNorm(nn.Module):
+ r"""
+ GroupNorm layer modified to incorporate timestep embeddings.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ out_dim (`int`): The number of channels in the input to normalize (the scale and shift projections each have this many channels).
+ num_groups (`int`): The number of groups to separate the channels into.
+ act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
+ eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
+ """
+
+ def __init__(
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
+ ):
+ super().__init__()
+ self.num_groups = num_groups
+ self.eps = eps
+
+ if act_fn is None:
+ self.act = None
+ else:
+ self.act = get_activation(act_fn)
+
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ if self.act:
+ emb = self.act(emb)
+ emb = self.linear(emb)
+ emb = emb[:, :, None, None]
+ scale, shift = emb.chunk(2, dim=1)
+
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
+ x = x * (1 + scale) + shift
+ return x
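The adaptive norm layers re-added above (`AdaLayerNorm`, `AdaLayerNormZero`, `AdaGroupNorm`) all follow the same modulation pattern: a conditioning embedding is projected to scale/shift (and, for adaLN-Zero, gates), and the normalized activations are modulated as `norm(x) * (1 + scale) + shift`. A compact sketch with illustrative shapes:

```py
import torch
from torch import nn

dim, batch, seq = 32, 2, 8
norm = nn.LayerNorm(dim, elementwise_affine=False)   # affine params come from the conditioning instead
to_scale_shift = nn.Linear(dim, 2 * dim)

x = torch.randn(batch, seq, dim)
cond = torch.randn(batch, dim)                        # e.g. a timestep/class embedding
scale, shift = to_scale_shift(cond).chunk(2, dim=-1)

# Broadcast the per-sample scale/shift over the sequence dimension.
out = norm(x) * (1 + scale[:, None]) + shift[:, None]
print(out.shape)  # torch.Size([2, 8, 32])
```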
diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index ccad3f539051..f86028619554 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -110,10 +110,7 @@ def chunk_scanner(chunk_idx, _):
)
_, res = jax.lax.scan(
- f=chunk_scanner,
- init=0,
- xs=None,
- length=math.ceil(num_q / query_chunk_size), # start counter # stop counter
+ f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter
)
return jnp.concatenate(res, axis=-3) # fuse the chunked result back
@@ -141,7 +138,6 @@ class FlaxAttention(nn.Module):
Parameters `dtype`
"""
-
query_dim: int
heads: int = 8
dim_head: int = 64
@@ -266,7 +262,6 @@ class FlaxBasicTransformerBlock(nn.Module):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
-
dim: int
n_heads: int
d_head: int
@@ -352,7 +347,6 @@ class FlaxTransformer2DModel(nn.Module):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
"""
-
in_channels: int
n_heads: int
d_head: int
@@ -448,7 +442,6 @@ class FlaxFeedForward(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
@@ -478,7 +471,6 @@ class FlaxGEGLU(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
dim: int
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
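The reformatted `jax.lax.scan` call above drives `chunk_scanner` with `xs=None` and an explicit `length`, i.e. the body runs a fixed number of times while threading a chunk index as the carry and stacking the per-step outputs. A tiny hedged sketch of that scan pattern (the chunk contents are placeholders, not real attention outputs):

```py
import jax
import jax.numpy as jnp

def chunk_scanner(chunk_idx, _):
    # Stand-in for computing attention over one query chunk starting at `chunk_idx`.
    chunk = jnp.full((4,), chunk_idx, dtype=jnp.float32)
    return chunk_idx + 4, chunk

_, chunks = jax.lax.scan(f=chunk_scanner, init=0, xs=None, length=3)
print(chunks.shape)  # (3, 4) -- per-step outputs stacked along a new leading axis
```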
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 21eb3a32dc09..9856f3c7739c 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -16,7 +16,7 @@
import torch
import torch.nn.functional as F
-from torch import einsum, nn
+from torch import nn
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
@@ -40,50 +40,14 @@ class Attention(nn.Module):
A cross attention layer.
Parameters:
- query_dim (`int`):
- The number of channels in the query.
+ query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
- heads (`int`, *optional*, defaults to 8):
- The number of heads to use for multi-head attention.
- dim_head (`int`, *optional*, defaults to 64):
- The number of channels in each head.
- dropout (`float`, *optional*, defaults to 0.0):
- The dropout probability to use.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
- upcast_attention (`bool`, *optional*, defaults to False):
- Set to `True` to upcast the attention computation to `float32`.
- upcast_softmax (`bool`, *optional*, defaults to False):
- Set to `True` to upcast the softmax computation to `float32`.
- cross_attention_norm (`str`, *optional*, defaults to `None`):
- The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
- cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups to use for the group norm in the cross attention.
- added_kv_proj_dim (`int`, *optional*, defaults to `None`):
- The number of channels to use for the added key and value projections. If `None`, no projection is used.
- norm_num_groups (`int`, *optional*, defaults to `None`):
- The number of groups to use for the group norm in the attention.
- spatial_norm_dim (`int`, *optional*, defaults to `None`):
- The number of channels to use for the spatial normalization.
- out_bias (`bool`, *optional*, defaults to `True`):
- Set to `True` to use a bias in the output linear layer.
- scale_qk (`bool`, *optional*, defaults to `True`):
- Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
- only_cross_attention (`bool`, *optional*, defaults to `False`):
- Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
- `added_kv_proj_dim` is not `None`.
- eps (`float`, *optional*, defaults to 1e-5):
- An additional value added to the denominator in group normalization that is used for numerical stability.
- rescale_output_factor (`float`, *optional*, defaults to 1.0):
- A factor to rescale the output by dividing it with this value.
- residual_connection (`bool`, *optional*, defaults to `False`):
- Set to `True` to add the residual connection to the output.
- _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
- Set to `True` if the attention block is loaded from a deprecated state dict.
- processor (`AttnProcessor`, *optional*, defaults to `None`):
- The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
- `AttnProcessor` otherwise.
"""
def __init__(
@@ -93,7 +57,7 @@ def __init__(
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
- bias: bool = False,
+ bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
@@ -107,7 +71,7 @@ def __init__(
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
- _from_deprecated_attn_block: bool = False,
+ _from_deprecated_attn_block=False,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
@@ -208,17 +172,7 @@ def __init__(
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
- ) -> None:
- r"""
- Set whether to use memory efficient attention from `xformers` or not.
-
- Args:
- use_memory_efficient_attention_xformers (`bool`):
- Whether to use memory efficient attention from `xformers` or not.
- attention_op (`Callable`, *optional*):
- The attention operation to use. Defaults to `None` which uses the default attention operation from
- `xformers`.
- """
+ ):
is_lora = hasattr(self, "processor") and isinstance(
self.processor,
LORA_ATTENTION_PROCESSORS,
@@ -340,14 +294,7 @@ def set_use_memory_efficient_attention_xformers(
self.set_processor(processor)
- def set_attention_slice(self, slice_size: int) -> None:
- r"""
- Set the slice size for attention computation.
-
- Args:
- slice_size (`int`):
- The slice size for attention computation.
- """
+ def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
@@ -368,17 +315,8 @@ def set_attention_slice(self, slice_size: int) -> None:
self.set_processor(processor)
- def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False) -> None:
- r"""
- Set the attention processor to use.
-
- Args:
- processor (`AttnProcessor`):
- The attention processor to use.
- _remove_lora (`bool`, *optional*, defaults to `False`):
- Set to `True` to remove LoRA layers from the model.
- """
- if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
+ def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
+ if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
deprecate(
"set_processor to offload LoRA",
"0.26.0",
@@ -404,16 +342,6 @@ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False)
self.processor = processor
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
- r"""
- Get the attention processor in use.
-
- Args:
- return_deprecated_lora (`bool`, *optional*, defaults to `False`):
- Set to `True` to return the deprecated LoRA attention processor.
-
- Returns:
- "AttentionProcessor": The attention processor in use.
- """
if not return_deprecated_lora:
return self.processor
@@ -493,29 +421,7 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce
return lora_processor
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- **cross_attention_kwargs,
- ) -> torch.Tensor:
- r"""
- The forward method of the `Attention` class.
-
- Args:
- hidden_states (`torch.Tensor`):
- The hidden states of the query.
- encoder_hidden_states (`torch.Tensor`, *optional*):
- The hidden states of the encoder.
- attention_mask (`torch.Tensor`, *optional*):
- The attention mask to use. If `None`, no mask is applied.
- **cross_attention_kwargs:
- Additional keyword arguments to pass along to the cross attention.
-
- Returns:
- `torch.Tensor`: The output of the attention layer.
- """
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
@@ -527,36 +433,14 @@ def forward(
**cross_attention_kwargs,
)
- def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
- r"""
- Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
- is the number of heads initialized while constructing the `Attention` class.
-
- Args:
- tensor (`torch.Tensor`): The tensor to reshape.
-
- Returns:
- `torch.Tensor`: The reshaped tensor.
- """
+ def batch_to_head_dim(self, tensor):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
- def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
- r"""
- Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
- the number of heads initialized while constructing the `Attention` class.
-
- Args:
- tensor (`torch.Tensor`): The tensor to reshape.
- out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
- reshaped to `[batch_size * heads, seq_len, dim // heads]`.
-
- Returns:
- `torch.Tensor`: The reshaped tensor.
- """
+ def head_to_batch_dim(self, tensor, out_dim=3):
head_size = self.heads
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
@@ -567,20 +451,7 @@ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Ten
return tensor
- def get_attention_scores(
- self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
- ) -> torch.Tensor:
- r"""
- Compute the attention scores.
-
- Args:
- query (`torch.Tensor`): The query tensor.
- key (`torch.Tensor`): The key tensor.
- attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
-
- Returns:
- `torch.Tensor`: The attention probabilities/scores.
- """
+ def get_attention_scores(self, query, key, attention_mask=None):
dtype = query.dtype
if self.upcast_attention:
query = query.float()
@@ -614,25 +485,7 @@ def get_attention_scores(
return attention_probs
- def prepare_attention_mask(
- self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
- ) -> torch.Tensor:
- r"""
- Prepare the attention mask for the attention computation.
-
- Args:
- attention_mask (`torch.Tensor`):
- The attention mask to prepare.
- target_length (`int`):
- The target length of the attention mask. This is the length of the attention mask after padding.
- batch_size (`int`):
- The batch size, which is used to repeat the attention mask.
- out_dim (`int`, *optional*, defaults to `3`):
- The output dimension of the attention mask. Can be either `3` or `4`.
-
- Returns:
- `torch.Tensor`: The prepared attention mask.
- """
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size, out_dim=3):
head_size = self.heads
if attention_mask is None:
return attention_mask
@@ -661,17 +514,7 @@ def prepare_attention_mask(
return attention_mask
- def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
- r"""
- Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
- `Attention` class.
-
- Args:
- encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
-
- Returns:
- `torch.Tensor`: The normalized encoder hidden states.
- """
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
if isinstance(self.norm_cross, nn.LayerNorm):
@@ -699,12 +542,12 @@ class AttnProcessor:
def __call__(
self,
attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- temb: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
- ) -> torch.Tensor:
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ scale=1.0,
+ ):
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
@@ -781,12 +624,12 @@ class CustomDiffusionAttnProcessor(nn.Module):
def __init__(
self,
- train_kv: bool = True,
- train_q_out: bool = True,
- hidden_size: Optional[int] = None,
- cross_attention_dim: Optional[int] = None,
- out_bias: bool = True,
- dropout: float = 0.0,
+ train_kv=True,
+ train_q_out=True,
+ hidden_size=None,
+ cross_attention_dim=None,
+ out_bias=True,
+ dropout=0.0,
):
super().__init__()
self.train_kv = train_kv
@@ -805,13 +648,7 @@ def __init__(
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.Tensor:
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out:
@@ -870,18 +707,8 @@ class AttnAddedKVProcessor:
encoder.
"""
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
- ) -> torch.Tensor:
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
-
- args = () if USE_PEFT_BACKEND else (scale,)
-
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -894,17 +721,17 @@ def __call__(
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states, scale=scale)
query = attn.head_to_batch_dim(query)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, *args)
- value = attn.to_v(hidden_states, *args)
+ key = attn.to_k(hidden_states, scale=scale)
+ value = attn.to_v(hidden_states, scale=scale)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -918,7 +745,7 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
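The added-KV processors above prepend projections of the encoder hidden states to the self-attention keys and values, so each query attends over both streams at once. A minimal single-head sketch of that pattern, with toy sizes and names chosen purely for illustration:

```py
import torch
from torch import nn

# toy sizes -- illustrative only, not the library defaults
batch, seq, enc_seq, dim = 2, 16, 77, 64

to_q, to_k, to_v = (nn.Linear(dim, dim) for _ in range(3))
add_k_proj, add_v_proj = (nn.Linear(dim, dim) for _ in range(2))  # project encoder states into KV space

hidden_states = torch.randn(batch, seq, dim)
encoder_hidden_states = torch.randn(batch, enc_seq, dim)

query = to_q(hidden_states)
# encoder projections are prepended to the self-attention keys/values,
# so every query attends over the encoder tokens and the sample itself
key = torch.cat([add_k_proj(encoder_hidden_states), to_k(hidden_states)], dim=1)
value = torch.cat([add_v_proj(encoder_hidden_states), to_v(hidden_states)], dim=1)

attn_weights = torch.softmax(query @ key.transpose(-1, -2) / dim**0.5, dim=-1)
out = attn_weights @ value  # (batch, seq, dim)
```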
@@ -940,18 +767,8 @@ def __init__(self):
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
- ) -> torch.Tensor:
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
-
- args = () if USE_PEFT_BACKEND else (scale,)
-
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -964,7 +781,7 @@ def __call__(
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states, scale=scale)
query = attn.head_to_batch_dim(query, out_dim=4)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -973,8 +790,8 @@ def __call__(
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, *args)
- value = attn.to_v(hidden_states, *args)
+ key = attn.to_k(hidden_states, scale=scale)
+ value = attn.to_v(hidden_states, scale=scale)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -991,7 +808,7 @@ def __call__(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1016,13 +833,7 @@ class XFormersAttnAddedKVProcessor:
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.Tensor:
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -1095,11 +906,9 @@ def __call__(
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
- ) -> torch.FloatTensor:
+ ):
residual = hidden_states
- args = () if USE_PEFT_BACKEND else (scale,)
-
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1127,15 +936,15 @@ def __call__(
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, *args)
+ query = attn.to_q(hidden_states, scale=scale)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states, *args)
- value = attn.to_v(encoder_hidden_states, *args)
+ key = attn.to_k(encoder_hidden_states, scale=scale)
+ value = attn.to_v(encoder_hidden_states, scale=scale)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
@@ -1148,7 +957,7 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1175,16 +984,14 @@ def __init__(self):
def __call__(
self,
attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- temb: Optional[torch.FloatTensor] = None,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
scale: float = 1.0,
- ) -> torch.FloatTensor:
+ ):
residual = hidden_states
- args = () if USE_PEFT_BACKEND else (scale,)
-
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1215,8 +1022,12 @@ def __call__(
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = attn.to_k(encoder_hidden_states, *args)
- value = attn.to_v(encoder_hidden_states, *args)
+ key = (
+ attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
+ )
+ value = (
+ attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
+ )
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -1236,7 +1047,9 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)
# linear proj
- hidden_states = attn.to_out[0](hidden_states, *args)
+ hidden_states = (
+ attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
+ )
# dropout
hidden_states = attn.to_out[1](hidden_states)
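The `*_2_0` processors route everything through `torch.nn.functional.scaled_dot_product_attention`, which expects `(batch, heads, seq_len, head_dim)` tensors. A rough sketch of the reshape-attend-reshape round trip used above, with toy shapes:

```py
import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 16, 8, 64
inner_dim = heads * head_dim

query = torch.randn(batch, seq, inner_dim)
key = torch.randn(batch, seq, inner_dim)
value = torch.randn(batch, seq, inner_dim)

# (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
q = query.view(batch, -1, heads, head_dim).transpose(1, 2)
k = key.view(batch, -1, heads, head_dim).transpose(1, 2)
v = value.view(batch, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# back to (batch, seq, heads * head_dim) for the output projection
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
```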
@@ -1276,12 +1089,12 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
def __init__(
self,
- train_kv: bool = True,
- train_q_out: bool = False,
- hidden_size: Optional[int] = None,
- cross_attention_dim: Optional[int] = None,
- out_bias: bool = True,
- dropout: float = 0.0,
+ train_kv=True,
+ train_q_out=False,
+ hidden_size=None,
+ cross_attention_dim=None,
+ out_bias=True,
+ dropout=0.0,
attention_op: Optional[Callable] = None,
):
super().__init__()
@@ -1302,13 +1115,7 @@ def __init__(
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
@@ -1363,7 +1170,6 @@ def __call__(
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
-
return hidden_states
@@ -1389,12 +1195,12 @@ class CustomDiffusionAttnProcessor2_0(nn.Module):
def __init__(
self,
- train_kv: bool = True,
- train_q_out: bool = True,
- hidden_size: Optional[int] = None,
- cross_attention_dim: Optional[int] = None,
- out_bias: bool = True,
- dropout: float = 0.0,
+ train_kv=True,
+ train_q_out=True,
+ hidden_size=None,
+ cross_attention_dim=None,
+ out_bias=True,
+ dropout=0.0,
):
super().__init__()
self.train_kv = train_kv
@@ -1413,13 +1219,7 @@ def __init__(
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if self.train_q_out:
@@ -1436,11 +1236,8 @@ def __call__(
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv:
- key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
- value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
- key = key.to(attn.to_q.weight.dtype)
- value = value.to(attn.to_q.weight.dtype)
-
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -1491,16 +1288,10 @@ class SlicedAttnProcessor:
`attention_head_dim` must be a multiple of the `slice_size`.
"""
- def __init__(self, slice_size: int):
+ def __init__(self, slice_size):
self.slice_size = slice_size
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
input_ndim = hidden_states.ndim
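`SlicedAttnProcessor` trades speed for memory by computing attention over chunks of the `batch * heads` dimension, controlled by `slice_size`. A simplified sketch of that slicing loop (plain single-pass softmax attention, illustrative shapes):

```py
import torch

def sliced_attention(query, key, value, slice_size):
    # query/key/value: (batch * heads, seq_len, head_dim)
    bh, _, head_dim = query.shape
    out = torch.empty_like(query)
    for start in range(0, bh, slice_size):
        end = min(start + slice_size, bh)
        # attention is computed one slice of batch*heads at a time,
        # so peak memory scales with slice_size instead of bh
        scores = query[start:end] @ key[start:end].transpose(-1, -2) * head_dim**-0.5
        out[start:end] = scores.softmax(dim=-1) @ value[start:end]
    return out

q = k = v = torch.randn(16, 64, 40)  # e.g. 2 samples x 8 heads
chunked = sliced_attention(q, k, v, slice_size=4)
```

Smaller `slice_size` values lower peak memory at the cost of more kernel launches.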
@@ -1581,14 +1372,7 @@ class SlicedAttnAddedKVProcessor:
def __init__(self, slice_size):
self.slice_size = slice_size
- def __call__(
- self,
- attn: "Attention",
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- temb: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
residual = hidden_states
if attn.spatial_norm is not None:
@@ -1662,26 +1446,20 @@ def __call__(
class SpatialNorm(nn.Module):
"""
- Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
-
- Args:
- f_channels (`int`):
- The number of channels for input to group normalization layer, and output of the spatial norm layer.
- zq_channels (`int`):
- The number of channels for the quantized vector as described in the paper.
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
"""
def __init__(
self,
- f_channels: int,
- zq_channels: int,
+ f_channels,
+ zq_channels,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
- def forward(self, f: torch.FloatTensor, zq: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, f, zq):
f_size = f.shape[-2:]
zq = F.interpolate(zq, size=f_size, mode="nearest")
norm_f = self.norm_layer(f)
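`SpatialNorm` conditions a group normalization on the quantized latent `zq`: `zq` is resized to the feature map's spatial size and turned into a per-pixel scale and shift. A standalone sketch of that forward pass, with illustrative channel counts:

```py
import torch
import torch.nn as nn
import torch.nn.functional as F

f_channels, zq_channels = 64, 4
norm = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1)  # per-pixel scale
conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1)  # per-pixel shift

f = torch.randn(1, f_channels, 32, 32)   # decoder feature map
zq = torch.randn(1, zq_channels, 8, 8)   # quantized latent

zq = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
out = norm(f) * conv_y(zq) + conv_b(zq)
```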
@@ -1703,18 +1481,9 @@ class LoRAAttnProcessor(nn.Module):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
- kwargs (`dict`):
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
"""
- def __init__(
- self,
- hidden_size: int,
- cross_attention_dim: Optional[int] = None,
- rank: int = 4,
- network_alpha: Optional[int] = None,
- **kwargs,
- ):
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
super().__init__()
self.hidden_size = hidden_size
@@ -1741,7 +1510,7 @@ def __init__(
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
- def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -1776,18 +1545,9 @@ class LoRAAttnProcessor2_0(nn.Module):
The dimension of the LoRA update matrices.
network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
- kwargs (`dict`):
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
"""
- def __init__(
- self,
- hidden_size: int,
- cross_attention_dim: Optional[int] = None,
- rank: int = 4,
- network_alpha: Optional[int] = None,
- **kwargs,
- ):
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -1816,7 +1576,7 @@ def __init__(
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
- def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -1855,17 +1615,16 @@ class LoRAXFormersAttnProcessor(nn.Module):
operator.
network_alpha (`int`, *optional*):
            Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
- kwargs (`dict`):
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
+
"""
def __init__(
self,
- hidden_size: int,
- cross_attention_dim: int,
- rank: int = 4,
+ hidden_size,
+ cross_attention_dim,
+ rank=4,
attention_op: Optional[Callable] = None,
- network_alpha: Optional[int] = None,
+ network_alpha=None,
**kwargs,
):
super().__init__()
@@ -1895,7 +1654,7 @@ def __init__(
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
- def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -1928,19 +1687,10 @@ class LoRAAttnAddedKVProcessor(nn.Module):
The number of channels in the `encoder_hidden_states`.
rank (`int`, defaults to 4):
The dimension of the LoRA update matrices.
- network_alpha (`int`, *optional*):
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
- kwargs (`dict`):
- Additional keyword arguments to pass to the `LoRALinearLayer` layers.
+
"""
- def __init__(
- self,
- hidden_size: int,
- cross_attention_dim: Optional[int] = None,
- rank: int = 4,
- network_alpha: Optional[int] = None,
- ):
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()
self.hidden_size = hidden_size
@@ -1954,7 +1704,7 @@ def __init__(
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
- def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
self_cls_name = self.__class__.__name__
deprecate(
self_cls_name,
@@ -1975,288 +1725,6 @@ def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **k
return attn.processor(attn, hidden_states, *args, **kwargs)
-class IPAdapterAttnProcessor(nn.Module):
- r"""
- Attention processor for IP-Adapater.
-
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- num_tokens (`int`, defaults to 4):
- The context length of the image features.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
- super().__init__()
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.num_tokens = num_tokens
- self.scale = scale
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- scale=1.0,
- ):
- if scale != 1.0:
- logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.")
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- # split hidden states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- query = attn.head_to_batch_dim(query)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
-
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class IPAdapterAttnProcessor2_0(torch.nn.Module):
- r"""
- Attention processor for IP-Adapater for PyTorch 2.0.
-
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- num_tokens (`int`, defaults to 4):
- The context length of the image features.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- """
-
- def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.num_tokens = num_tokens
- self.scale = scale
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- scale=1.0,
- ):
- if scale != 1.0:
- logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
- residual = hidden_states
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- # split hidden states
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- encoder_hidden_states[:, end_pos:, :],
- )
-
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
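Both IP-Adapter processors removed above follow the same recipe: split `encoder_hidden_states` into text tokens and trailing image-prompt tokens, attend to each separately (the image branch through its own `to_k_ip`/`to_v_ip` projections), and merge the image result back with a fixed weight. A stripped-down single-head sketch of that split-and-merge, with toy shapes:

```py
import torch
from torch import nn

dim, num_tokens, ip_scale = 64, 4, 1.0   # num_tokens = context length of the image features
to_q, to_k, to_v = (nn.Linear(dim, dim) for _ in range(3))
to_k_ip = nn.Linear(dim, dim, bias=False)
to_v_ip = nn.Linear(dim, dim, bias=False)

hidden_states = torch.randn(2, 16, dim)
encoder_hidden_states = torch.randn(2, 77 + num_tokens, dim)  # text tokens followed by image tokens

# split off the trailing image-prompt tokens
end_pos = encoder_hidden_states.shape[1] - num_tokens
text_states = encoder_hidden_states[:, :end_pos, :]
ip_states = encoder_hidden_states[:, end_pos:, :]

def attend(q, k, v):
    return torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1) @ v

query = to_q(hidden_states)
text_out = attend(query, to_k(text_states), to_v(text_states))
image_out = attend(query, to_k_ip(ip_states), to_v_ip(ip_states))

# the image branch is merged back with a fixed weight
out = text_out + ip_scale * image_out
```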
-# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
-# this way torch.compile and co. will work as well
-class Kandi3AttnProcessor:
- r"""
- Default kandinsky3 proccesor for performing attention-related computations.
- """
-
- @staticmethod
- def _reshape(hid_states, h):
- b, n, f = hid_states.shape
- d = f // h
- return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
-
- def __call__(
- self,
- attn,
- x,
- context,
- context_mask=None,
- ):
- query = self._reshape(attn.to_q(x), h=attn.num_heads)
- key = self._reshape(attn.to_k(context), h=attn.num_heads)
- value = self._reshape(attn.to_v(context), h=attn.num_heads)
-
- attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
-
- if context_mask is not None:
- max_neg_value = -torch.finfo(attention_matrix.dtype).max
- context_mask = context_mask.unsqueeze(1).unsqueeze(1)
- attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
- attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
-
- out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
- out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
- out = attn.to_out[0](out)
- return out
-
-
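The Kandinsky 3 processor above does its head split and merge with `einsum` rather than the usual `view`/`transpose` pair. A hedged, self-contained rendering of the same computation (mask handling omitted, shapes illustrative):

```py
import torch
from torch import einsum

def kandinsky_style_attention(q, k, v, num_heads, scale):
    # q: (b, n_q, h*d), k/v: (b, n_kv, h*d)
    def split_heads(t):
        b, n, f = t.shape
        return t.reshape(b, n, num_heads, f // num_heads).permute(0, 2, 1, 3)

    q, k, v = map(split_heads, (q, k, v))
    scores = einsum("b h i d, b h j d -> b h i j", q, k) * scale
    probs = scores.softmax(dim=-1)
    out = einsum("b h i j, b h j d -> b h i d", probs, v)
    # merge heads back: (b, h, n_q, d) -> (b, n_q, h*d)
    return out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)

x = torch.randn(2, 16, 64)
ctx = torch.randn(2, 77, 64)
out = kandinsky_style_attention(x, ctx, ctx, num_heads=8, scale=(64 // 8) ** -0.5)
```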
LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2280,9 +1748,6 @@ def __call__(
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
- IPAdapterAttnProcessor,
- IPAdapterAttnProcessor2_0,
- Kandi3AttnProcessor,
)
AttentionProcessor = Union[
@@ -2297,7 +1762,7 @@ def __call__(
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
- # deprecated
+    # deprecated
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
diff --git a/src/diffusers/models/autoencoder_asym_kl.py b/src/diffusers/models/autoencoder_asym_kl.py
index 678e47234096..d8099120918b 100644
--- a/src/diffusers/models/autoencoder_asym_kl.py
+++ b/src/diffusers/models/autoencoder_asym_kl.py
@@ -18,7 +18,7 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils.accelerate_utils import apply_forward_hook
-from .modeling_outputs import AutoencoderKLOutput
+from .autoencoder_kl import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
@@ -65,11 +65,11 @@ def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
- down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
- down_block_out_channels: Tuple[int, ...] = (64,),
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ down_block_out_channels: Tuple[int] = (64,),
layers_per_down_block: int = 1,
- up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
- up_block_out_channels: Tuple[int, ...] = (64,),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ up_block_out_channels: Tuple[int] = (64,),
layers_per_up_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
@@ -108,13 +108,8 @@ def __init__(
self.use_slicing = False
self.use_tiling = False
- self.register_to_config(block_out_channels=up_block_out_channels)
- self.register_to_config(force_upcast=False)
-
@apply_forward_hook
- def encode(
- self, x: torch.FloatTensor, return_dict: bool = True
- ) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
@@ -130,7 +125,7 @@ def _decode(
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
- ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z, image, mask)
@@ -143,11 +138,10 @@ def _decode(
def decode(
self,
z: torch.FloatTensor,
- generator: Optional[torch.Generator] = None,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
- ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
decoded = self._decode(z, image, mask).sample
if not return_dict:
@@ -162,7 +156,7 @@ def forward(
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
- ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py
index 464bff9189dd..80d2cccd536d 100644
--- a/src/diffusers/models/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoder_kl.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import torch
@@ -18,6 +19,7 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
+from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -26,11 +28,24 @@
AttnAddedKVProcessor,
AttnProcessor,
)
-from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -279,9 +294,7 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod
return DecoderOutput(sample=dec)
@apply_forward_hook
- def decode(
- self, z: torch.FloatTensor, return_dict: bool = True, generator=None
- ) -> Union[DecoderOutput, torch.FloatTensor]:
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
@@ -307,13 +320,13 @@ def decode(
return DecoderOutput(sample=decoded)
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ def blend_v(self, a, b, blend_extent):
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ def blend_h(self, a, b, blend_extent):
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
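`blend_v` and `blend_h` crossfade neighbouring VAE tiles over `blend_extent` rows or columns so tiled decoding leaves no visible seams. The same idea in isolation, horizontal case only and with toy tensors:

```py
import torch

def blend_h(a, b, blend_extent):
    # a and b are neighbouring tiles shaped (batch, channels, height, width);
    # the first `blend_extent` columns of b are linearly mixed with the last
    # `blend_extent` columns of a, ramping from pure a to pure b.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, x] * w
    return b

left = torch.zeros(1, 3, 8, 8)
right = torch.ones(1, 3, 8, 8)
blended = blend_h(left, right, blend_extent=4)  # columns 0..3 ramp from 0 toward 1
```

The vertical variant is identical with the roles of height and width swapped.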
diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py
index 56ccf30e0402..407b1906bba4 100644
--- a/src/diffusers/models/autoencoder_tiny.py
+++ b/src/diffusers/models/autoencoder_tiny.py
@@ -14,7 +14,7 @@
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
import torch
@@ -91,24 +91,23 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
`force_upcast` can be set to `False` (see this fp16-friendly
[AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
"""
-
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
- in_channels: int = 3,
- out_channels: int = 3,
- encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
- decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
+ in_channels=3,
+ out_channels=3,
+ encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
+ decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
act_fn: str = "relu",
latent_channels: int = 4,
upsampling_scaling_factor: int = 2,
- num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
- num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
+ num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
+ num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
latent_magnitude: int = 3,
latent_shift: float = 0.5,
- force_upcast: bool = False,
+        force_upcast: bool = False,
scaling_factor: float = 1.0,
):
super().__init__()
@@ -148,36 +147,33 @@ def __init__(
self.tile_sample_min_size = 512
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
- self.register_to_config(block_out_channels=decoder_block_out_channels)
- self.register_to_config(force_upcast=False)
-
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (EncoderTiny, DecoderTiny)):
module.gradient_checkpointing = value
- def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ def scale_latents(self, x):
"""raw latents -> [0, 1]"""
return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
- def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ def unscale_latents(self, x):
"""[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
- def enable_slicing(self) -> None:
+ def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
- def disable_slicing(self) -> None:
+ def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
- def enable_tiling(self, use_tiling: bool = True) -> None:
+ def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
@@ -185,7 +181,7 @@ def enable_tiling(self, use_tiling: bool = True) -> None:
"""
self.use_tiling = use_tiling
- def disable_tiling(self) -> None:
+ def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
@@ -201,9 +197,13 @@ def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
Args:
x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
Returns:
- `torch.FloatTensor`: Encoded batch of images.
+ [`~models.autoencoder_tiny.AutoencoderTinyOutput`] or `tuple`:
+ If return_dict is True, a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] is returned, otherwise a
+ plain `tuple` is returned.
"""
# scale of encoder output relative to input
sf = self.spatial_scale_factor
@@ -249,9 +249,13 @@ def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
Args:
x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.autoencoder_tiny.AutoencoderTinyOutput`] instead of a plain tuple.
Returns:
- `torch.FloatTensor`: Encoded batch of images.
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
"""
# scale of decoder output relative to input
sf = self.spatial_scale_factor
@@ -303,9 +307,7 @@ def encode(
return AutoencoderTinyOutput(latents=output)
@apply_forward_hook
- def decode(
- self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
- ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+ def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
if self.use_slicing and x.shape[0] > 1:
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
output = torch.cat(output)
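`AutoencoderTiny` keeps its latents in a fixed range: `scale_latents` maps raw latents into `[0, 1]` using `latent_magnitude` and `latent_shift`, and `unscale_latents` inverts that map before decoding. A minimal round-trip sketch with the defaults from the config above:

```py
import torch

latent_magnitude, latent_shift = 3, 0.5

def scale_latents(x):
    """raw latents -> [0, 1]"""
    return x.div(2 * latent_magnitude).add(latent_shift).clamp(0, 1)

def unscale_latents(x):
    """[0, 1] -> raw latents"""
    return x.sub(latent_shift).mul(2 * latent_magnitude)

raw = torch.randn(1, 4, 8, 8)
# the round trip is exact inside [-latent_magnitude, latent_magnitude]
assert torch.allclose(unscale_latents(scale_latents(raw)), raw.clamp(-3, 3), atol=1e-6)
```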
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index 3139bb2a5c6c..c0d2da9b8c5f 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -30,7 +30,12 @@
)
from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
-from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+)
from .unet_2d_condition import UNet2DConditionModel
@@ -71,7 +76,7 @@ def __init__(
self,
conditioning_embedding_channels: int,
conditioning_channels: int = 3,
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
):
super().__init__()
@@ -166,9 +171,6 @@ class conditioning with `class_embed_type` equal to `None`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
- TODO(Patrick) - unused parameter.
- addition_embed_type_num_heads (`int`, defaults to 64):
- The number of heads to use for the `TextTimeEmbedding` layer.
"""
_supports_gradient_checkpointing = True
@@ -180,15 +182,14 @@ def __init__(
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
- down_block_types: Tuple[str, ...] = (
+ down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -196,11 +197,11 @@ def __init__(
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
- transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
- attention_head_dim: Union[int, Tuple[int, ...]] = 8,
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
@@ -210,9 +211,9 @@ def __init__(
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
- addition_embed_type_num_heads: int = 64,
+ addition_embed_type_num_heads=64,
):
super().__init__()
@@ -405,44 +406,28 @@ def __init__(
controlnet_block = zero_module(controlnet_block)
self.controlnet_mid_block = controlnet_block
- if mid_block_type == "UNetMidBlock2DCrossAttn":
- self.mid_block = UNetMidBlock2DCrossAttn(
- transformer_layers_per_block=transformer_layers_per_block[-1],
- in_channels=mid_block_channel,
- temb_channels=time_embed_dim,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_time_scale_shift=resnet_time_scale_shift,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads[-1],
- resnet_groups=norm_num_groups,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- )
- elif mid_block_type == "UNetMidBlock2D":
- self.mid_block = UNetMidBlock2D(
- in_channels=block_out_channels[-1],
- temb_channels=time_embed_dim,
- num_layers=0,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_groups=norm_num_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- add_attention=False,
- )
- else:
- raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
controlnet_conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
- conditioning_channels: int = 3,
):
r"""
Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
@@ -489,10 +474,8 @@ def from_unet(
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
- mid_block_type=unet.config.mid_block_type,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
- conditioning_channels=conditioning_channels,
)
if load_weights_from_unet:
@@ -587,7 +570,7 @@ def set_default_attn_processor(self):
self.set_attn_processor(processor, _remove_lora=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
- def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
@@ -652,7 +635,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
@@ -670,7 +653,7 @@ def forward(
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
- ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
+ ) -> Union[ControlNetOutput, Tuple]:
"""
The [`ControlNetModel`] forward method.
@@ -811,16 +794,13 @@ def forward(
# 4. mid
if self.mid_block is not None:
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- )
- else:
- sample = self.mid_block(sample, emb)
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
# 5. Control net blocks
@@ -837,6 +817,7 @@ def forward(
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
+
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
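In guess mode, the ControlNet residuals are not applied uniformly: the `torch.logspace(-1, 0, ...)` ramp above weights shallow down-block residuals near 0.1 and the mid-block residual at 1.0 before multiplying by `conditioning_scale`. Evaluating the ramp directly makes that concrete (the residual count here is only illustrative):

```py
import torch

num_res_samples = 12  # e.g. down-block residuals; the +1 below covers the mid block
scales = torch.logspace(-1, 0, num_res_samples + 1)  # 0.1 ... 1.0, log-spaced

conditioning_scale = 0.5
scales = scales * conditioning_scale

print(scales[0].item(), scales[-1].item())  # ~0.05 for the first residual, 0.5 for the mid block
```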
diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py
index 34aaac549f8c..076e6183211b 100644
--- a/src/diffusers/models/controlnet_flax.py
+++ b/src/diffusers/models/controlnet_flax.py
@@ -46,10 +46,10 @@ class FlaxControlNetOutput(BaseOutput):
class FlaxControlNetConditioningEmbedding(nn.Module):
conditioning_embedding_channels: int
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
+ block_out_channels: Tuple[int] = (16, 32, 96, 256)
dtype: jnp.dtype = jnp.float32
- def setup(self) -> None:
+ def setup(self):
self.conv_in = nn.Conv(
self.block_out_channels[0],
kernel_size=(3, 3),
@@ -87,7 +87,7 @@ def setup(self) -> None:
dtype=self.dtype,
)
- def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray:
+ def __call__(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = nn.silu(embedding)
@@ -146,20 +146,19 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
"""
-
sample_size: int = 32
in_channels: int = 4
- down_block_types: Tuple[str, ...] = (
+ down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
- only_cross_attention: Union[bool, Tuple[bool, ...]] = False
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
+ only_cross_attention: Union[bool, Tuple[bool]] = False
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
- attention_head_dim: Union[int, Tuple[int, ...]] = 8
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
+ attention_head_dim: Union[int, Tuple[int]] = 8
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
@@ -167,7 +166,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
flip_sin_to_cos: bool = True
freq_shift: int = 0
controlnet_conditioning_channel_order: str = "rgb"
- conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
+ conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
@@ -183,7 +182,7 @@ def init_weights(self, rng: jax.Array) -> FrozenDict:
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
- def setup(self) -> None:
+ def setup(self):
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
@@ -313,21 +312,21 @@ def setup(self) -> None:
def __call__(
self,
- sample: jnp.ndarray,
- timesteps: Union[jnp.ndarray, float, int],
- encoder_hidden_states: jnp.ndarray,
- controlnet_cond: jnp.ndarray,
+ sample,
+ timesteps,
+ encoder_hidden_states,
+ controlnet_cond,
conditioning_scale: float = 1.0,
return_dict: bool = True,
train: bool = False,
- ) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
+ ) -> Union[FlaxControlNetOutput, Tuple]:
r"""
Args:
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
timestep (`jnp.ndarray` or `float` or `int`): timesteps
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
- conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
+ conditioning_scale: (`float`) the scale factor for controlnet outputs
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
plain tuple.
@@ -336,8 +335,8 @@ def __call__(
Returns:
[`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
- [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
- `tuple`. When returning a tuple, the first element is the sample tensor.
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the sample tensor.
"""
channel_order = self.controlnet_conditioning_channel_order
if channel_order == "bgr":
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index a377ae267411..d3422c8f58b2 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -66,22 +66,17 @@ def get_timestep_embedding(
return emb
-def get_2d_sincos_pos_embed(
- embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
-):
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
- if isinstance(grid_size, int):
- grid_size = (grid_size, grid_size)
-
- grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
- grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
- grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
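`get_2d_sincos_pos_embed` builds a fixed sin/cos table over a 2-D grid: half of the channels encode the row coordinate, the other half the column, each half being a standard 1-D sinusoidal embedding. A compact NumPy sketch of that construction (function names are made up for the example):

```py
import numpy as np

def sincos_1d(embed_dim, pos):
    # embed_dim must be even; pos is a flat array of coordinates
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
    out = np.outer(pos.astype(np.float64), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

def sincos_2d(embed_dim, grid_size):
    # row/column coordinate of every position in a grid_size x grid_size grid
    rows = np.repeat(np.arange(grid_size), grid_size)
    cols = np.tile(np.arange(grid_size), grid_size)
    emb_rows = sincos_1d(embed_dim // 2, rows)
    emb_cols = sincos_1d(embed_dim // 2, cols)
    return np.concatenate([emb_rows, emb_cols], axis=1)  # (grid_size**2, embed_dim)

pos_embed = sincos_2d(64, grid_size=8)
print(pos_embed.shape)  # (64, 64)
```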
@@ -134,7 +129,6 @@ def __init__(
layer_norm=False,
flatten=True,
bias=True,
- interpolation_scale=1,
):
super().__init__()
@@ -150,41 +144,16 @@ def __init__(
else:
self.norm = None
- self.patch_size = patch_size
- # See:
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
- self.height, self.width = height // patch_size, width // patch_size
- self.base_size = height // patch_size
- self.interpolation_scale = interpolation_scale
- pos_embed = get_2d_sincos_pos_embed(
- embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
- )
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
- height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
-
- # Interpolate positional embeddings if needed.
- # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
- if self.height != height or self.width != width:
- pos_embed = get_2d_sincos_pos_embed(
- embed_dim=self.pos_embed.shape[-1],
- grid_size=(height, width),
- base_size=self.base_size,
- interpolation_scale=self.interpolation_scale,
- )
- pos_embed = torch.from_numpy(pos_embed)
- pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
- else:
- pos_embed = self.pos_embed
-
- return (latent + pos_embed).to(latent.dtype)
+ return latent + self.pos_embed
class TimestepEmbedding(nn.Module):
@@ -282,33 +251,6 @@ def forward(self, x):
return out
-class SinusoidalPositionalEmbedding(nn.Module):
- """Apply positional information to a sequence of embeddings.
-
- Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
- them
-
- Args:
- embed_dim: (int): Dimension of the positional embedding.
- max_seq_length: Maximum sequence length to apply positional embeddings
-
- """
-
- def __init__(self, embed_dim: int, max_seq_length: int = 32):
- super().__init__()
- position = torch.arange(max_seq_length).unsqueeze(1)
- div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
- pe = torch.zeros(1, max_seq_length, embed_dim)
- pe[0, :, 0::2] = torch.sin(position * div_term)
- pe[0, :, 1::2] = torch.cos(position * div_term)
- self.register_buffer("pe", pe)
-
- def forward(self, x):
- _, seq_length, _ = x.shape
- x = x + self.pe[:, :seq_length]
- return x
-
-
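For reference, the `SinusoidalPositionalEmbedding` removed above is the classic fixed transformer position table: sines on even channels, cosines on odd ones, added to the input sequence. Collected into a runnable form:

```py
import math
import torch
from torch import nn

class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)  # even channels
        pe[0, :, 1::2] = torch.cos(position * div_term)  # odd channels
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_length, _ = x.shape
        return x + self.pe[:, :seq_length]

emb = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=32)
out = emb(torch.randn(2, 16, 64))
```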
class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
@@ -714,79 +656,3 @@ def forward(
objs = torch.cat([objs_text, objs_image], dim=1)
return objs
-
-
-class CombinedTimestepSizeEmbeddings(nn.Module):
- """
- For PixArt-Alpha.
-
- Reference:
- https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
- """
-
- def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
- super().__init__()
-
- self.outdim = size_emb_dim
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
-
- self.use_additional_conditions = use_additional_conditions
- if use_additional_conditions:
- self.use_additional_conditions = True
- self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
- self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
- self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
-
- def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
- if size.ndim == 1:
- size = size[:, None]
-
- if size.shape[0] != batch_size:
- size = size.repeat(batch_size // size.shape[0], 1)
- if size.shape[0] != batch_size:
- raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")
-
- current_batch_size, dims = size.shape[0], size.shape[1]
- size = size.reshape(-1)
- size_freq = self.additional_condition_proj(size).to(size.dtype)
-
- size_emb = embedder(size_freq)
- size_emb = size_emb.reshape(current_batch_size, dims * self.outdim)
- return size_emb
-
- def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
- timesteps_proj = self.time_proj(timestep)
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
-
- if self.use_additional_conditions:
- resolution = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
- aspect_ratio = self.apply_condition(
- aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
- )
- conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
- else:
- conditioning = timesteps_emb
-
- return conditioning
-
-
-class CaptionProjection(nn.Module):
- """
- Projects caption embeddings. Also handles dropout for classifier-free guidance.
-
- Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
- """
-
- def __init__(self, in_features, hidden_size, num_tokens=120):
- super().__init__()
- self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
- self.act_1 = nn.GELU(approximate="tanh")
- self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
- self.register_buffer("y_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features**0.5))
-
- def forward(self, caption, force_drop_ids=None):
- hidden_states = self.linear_1(caption)
- hidden_states = self.act_1(hidden_states)
- hidden_states = self.linear_2(hidden_states)
- return hidden_states
diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py
index dd5c892990d3..88c2c45e4655 100644
--- a/src/diffusers/models/embeddings_flax.py
+++ b/src/diffusers/models/embeddings_flax.py
@@ -65,7 +65,6 @@ class FlaxTimestepEmbedding(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@@ -85,7 +84,6 @@ class FlaxTimesteps(nn.Module):
dim (`int`, *optional*, defaults to `32`):
Time step embedding dimension
"""
-
dim: int = 32
flip_sin_to_cos: bool = False
freq_shift: float = 1
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index daac8f902cd6..a143c17458ad 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -12,60 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-# IMPORTANT: #
-###################################################################
-# ----------------------------------------------------------------#
-# This file is deprecated and will be removed soon #
-# (as soon as PEFT will become a required dependency for LoRA) #
-# ----------------------------------------------------------------#
-###################################################################
-
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
+from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
from ..utils import logging
-from ..utils.import_utils import is_transformers_available
-
-
-if is_transformers_available():
- from transformers import CLIPTextModel, CLIPTextModelWithProjection
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-def text_encoder_attn_modules(text_encoder):
- attn_modules = []
-
- if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
- for i, layer in enumerate(text_encoder.text_model.encoder.layers):
- name = f"text_model.encoder.layers.{i}.self_attn"
- mod = layer.self_attn
- attn_modules.append((name, mod))
- else:
- raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
-
- return attn_modules
-
-
-def text_encoder_mlp_modules(text_encoder):
- mlp_modules = []
-
- if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
- for i, layer in enumerate(text_encoder.text_model.encoder.layers):
- mlp_mod = layer.mlp
- name = f"text_model.encoder.layers.{i}.mlp"
- mlp_modules.append((name, mlp_mod))
- else:
- raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
-
- return mlp_modules
-
-
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
@@ -80,95 +39,6 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
mlp_module.fc2.lora_scale = lora_scale
-class PatchedLoraProjection(torch.nn.Module):
- def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
- super().__init__()
- from ..models.lora import LoRALinearLayer
-
- self.regular_linear_layer = regular_linear_layer
-
- device = self.regular_linear_layer.weight.device
-
- if dtype is None:
- dtype = self.regular_linear_layer.weight.dtype
-
- self.lora_linear_layer = LoRALinearLayer(
- self.regular_linear_layer.in_features,
- self.regular_linear_layer.out_features,
- network_alpha=network_alpha,
- device=device,
- dtype=dtype,
- rank=rank,
- )
-
- self.lora_scale = lora_scale
-
- # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
- # when saving the whole text encoder model and when LoRA is unloaded or fused
- def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
- if self.lora_linear_layer is None:
- return self.regular_linear_layer.state_dict(
- *args, destination=destination, prefix=prefix, keep_vars=keep_vars
- )
-
- return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
-
- def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
- if self.lora_linear_layer is None:
- return
-
- dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
-
- w_orig = self.regular_linear_layer.weight.data.float()
- w_up = self.lora_linear_layer.up.weight.data.float()
- w_down = self.lora_linear_layer.down.weight.data.float()
-
- if self.lora_linear_layer.network_alpha is not None:
- w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
-
- fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
-
- if safe_fusing and torch.isnan(fused_weight).any().item():
- raise ValueError(
- "This LoRA weight seems to be broken. "
- f"Encountered NaN values when trying to fuse LoRA weights for {self}."
- "LoRA weights will not be fused."
- )
-
- self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
-
- # we can drop the lora layer now
- self.lora_linear_layer = None
-
- # offload the up and down matrices to CPU to not blow the memory
- self.w_up = w_up.cpu()
- self.w_down = w_down.cpu()
- self.lora_scale = lora_scale
-
- def _unfuse_lora(self):
- if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
- return
-
- fused_weight = self.regular_linear_layer.weight.data
- dtype, device = fused_weight.dtype, fused_weight.device
-
- w_up = self.w_up.to(device=device).float()
- w_down = self.w_down.to(device).float()
-
- unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
- self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
-
- self.w_up = None
- self.w_down = None
-
- def forward(self, input):
- if self.lora_scale is None:
- self.lora_scale = 1.0
- if self.lora_linear_layer is None:
- return self.regular_linear_layer(input)
- return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
-
-
class LoRALinearLayer(nn.Module):
r"""
A linear layer that is used with LoRA.
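
The hunk above relocates `PatchedLoraProjection` (and the text-encoder module helpers) to `diffusers.loaders`; its `_fuse_lora` folds the LoRA update into the frozen linear weight. A toy sketch of that arithmetic follows, using illustrative shapes rather than the class itself.

```py
# Toy sketch of the weight fusing done by _fuse_lora above: the low-rank update
# up @ down, scaled by lora_scale (and network_alpha / rank when set), is added
# into the frozen linear weight. Illustrative shapes, not the diffusers class.
import torch

out_features, in_features, rank = 8, 16, 4
w_orig = torch.randn(out_features, in_features)
w_up = torch.randn(out_features, rank)
w_down = torch.randn(rank, in_features)
lora_scale, network_alpha = 1.0, None

if network_alpha is not None:
    w_up = w_up * network_alpha / rank

# equivalent to the batched torch.bmm(w_up[None], w_down[None])[0] in the hunk
fused = w_orig + lora_scale * (w_up @ w_down)
assert fused.shape == w_orig.shape
```
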
diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py
index 0ea0819ca07a..ea4d1bfea548 100644
--- a/src/diffusers/models/modeling_flax_utils.py
+++ b/src/diffusers/models/modeling_flax_utils.py
@@ -52,7 +52,6 @@ class FlaxModelMixin(PushToHubMixin):
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
"""
-
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_flax_internal_args = ["name", "parent", "dtype"]
@@ -437,7 +436,7 @@ def from_pretrained(
# make sure all arrays are stored as jnp.ndarray
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
# https://github.com/google/flax/issues/1261
- state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
# flatten dicts
state = flatten_dict(state)
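
The one-line change above swaps `jax.local_devices(backend="cpu")[0]` for `jax.devices("cpu")[0]`; both resolve to the host CPU device that every checkpoint array is pinned to before the state dict is flattened. A minimal sketch with a toy state, illustrative only:

```py
# Both device lookups in the hunk above resolve to the host CPU; the tree_map
# then pins every checkpoint array to it before flattening. Toy state only.
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]                    # new form
# cpu = jax.local_devices(backend="cpu")[0]    # previous, equivalent form
state = {"params": {"kernel": jnp.ones((2, 2))}}
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu), state)
```
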
diff --git a/src/diffusers/models/modeling_pytorch_flax_utils.py b/src/diffusers/models/modeling_pytorch_flax_utils.py
index 17b521b00145..a61638ad02f7 100644
--- a/src/diffusers/models/modeling_pytorch_flax_utils.py
+++ b/src/diffusers/models/modeling_pytorch_flax_utils.py
@@ -1,161 +1,161 @@
-# coding=utf-8
-# Copyright 2023 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" PyTorch - Flax general utilities."""
-
-from pickle import UnpicklingError
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-from flax.serialization import from_bytes
-from flax.traverse_util import flatten_dict
-
-from ..utils import logging
-
-
-logger = logging.get_logger(__name__)
-
-
-#####################
-# Flax => PyTorch #
-#####################
-
-
-# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
-def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
- try:
- with open(model_file, "rb") as flax_state_f:
- flax_state = from_bytes(None, flax_state_f.read())
- except UnpicklingError as e:
- try:
- with open(model_file) as f:
- if f.read().startswith("version"):
- raise OSError(
- "You seem to have cloned a repository without having git-lfs installed. Please"
- " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
- " folder you cloned."
- )
- else:
- raise ValueError from e
- except (UnicodeDecodeError, ValueError):
- raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
-
- return load_flax_weights_in_pytorch_model(pt_model, flax_state)
-
-
-def load_flax_weights_in_pytorch_model(pt_model, flax_state):
- """Load flax checkpoints in a PyTorch model"""
-
- try:
- import torch # noqa: F401
- except ImportError:
- logger.error(
- "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
- " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
- " instructions."
- )
- raise
-
- # check if we have bf16 weights
- is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
- if any(is_type_bf16):
- # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
-
- # and bf16 is not fully supported in PT yet.
- logger.warning(
- "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
- "before loading those in PyTorch model."
- )
- flax_state = jax.tree_util.tree_map(
- lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
- )
-
- pt_model.base_model_prefix = ""
-
- flax_state_dict = flatten_dict(flax_state, sep=".")
- pt_model_dict = pt_model.state_dict()
-
- # keep track of unexpected & missing keys
- unexpected_keys = []
- missing_keys = set(pt_model_dict.keys())
-
- for flax_key_tuple, flax_tensor in flax_state_dict.items():
- flax_key_tuple_array = flax_key_tuple.split(".")
-
- if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
- flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
- flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
- elif flax_key_tuple_array[-1] == "kernel":
- flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
- flax_tensor = flax_tensor.T
- elif flax_key_tuple_array[-1] == "scale":
- flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
-
- if "time_embedding" not in flax_key_tuple_array:
- for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
- flax_key_tuple_array[i] = (
- flax_key_tuple_string.replace("_0", ".0")
- .replace("_1", ".1")
- .replace("_2", ".2")
- .replace("_3", ".3")
- .replace("_4", ".4")
- .replace("_5", ".5")
- .replace("_6", ".6")
- .replace("_7", ".7")
- .replace("_8", ".8")
- .replace("_9", ".9")
- )
-
- flax_key = ".".join(flax_key_tuple_array)
-
- if flax_key in pt_model_dict:
- if flax_tensor.shape != pt_model_dict[flax_key].shape:
- raise ValueError(
- f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
- f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
- )
- else:
- # add weight to pytorch dict
- flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
- pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
- # remove from missing keys
- missing_keys.remove(flax_key)
- else:
- # weight is not expected by PyTorch model
- unexpected_keys.append(flax_key)
-
- pt_model.load_state_dict(pt_model_dict)
-
- # re-transform missing_keys to list
- missing_keys = list(missing_keys)
-
- if len(unexpected_keys) > 0:
- logger.warning(
- "Some weights of the Flax model were not used when initializing the PyTorch model"
- f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
- f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
- " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
- f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
- " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
- " FlaxBertForSequenceClassification model)."
- )
- if len(missing_keys) > 0:
- logger.warning(
- f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
- f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
- " use it for predictions and inference."
- )
-
- return pt_model
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch - Flax general utilities."""
+
+from pickle import UnpicklingError
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.serialization import from_bytes
+from flax.traverse_util import flatten_dict
+
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+#####################
+# Flax => PyTorch #
+#####################
+
+
+# from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py#L224-L352
+def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
+ try:
+ with open(model_file, "rb") as flax_state_f:
+ flax_state = from_bytes(None, flax_state_f.read())
+ except UnpicklingError as e:
+ try:
+ with open(model_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please"
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+ " folder you cloned."
+ )
+ else:
+ raise ValueError from e
+ except (UnicodeDecodeError, ValueError):
+ raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
+
+ return load_flax_weights_in_pytorch_model(pt_model, flax_state)
+
+
+def load_flax_weights_in_pytorch_model(pt_model, flax_state):
+ """Load flax checkpoints in a PyTorch model"""
+
+ try:
+ import torch # noqa: F401
+ except ImportError:
+ logger.error(
+ "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
+ " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
+ " instructions."
+ )
+ raise
+
+ # check if we have bf16 weights
+ is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
+ if any(is_type_bf16):
+ # convert all weights to fp32 if they are bf16 since torch.from_numpy can-not handle bf16
+
+ # and bf16 is not fully supported in PT yet.
+ logger.warning(
+ "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
+ "before loading those in PyTorch model."
+ )
+ flax_state = jax.tree_util.tree_map(
+ lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
+ )
+
+ pt_model.base_model_prefix = ""
+
+ flax_state_dict = flatten_dict(flax_state, sep=".")
+ pt_model_dict = pt_model.state_dict()
+
+ # keep track of unexpected & missing keys
+ unexpected_keys = []
+ missing_keys = set(pt_model_dict.keys())
+
+ for flax_key_tuple, flax_tensor in flax_state_dict.items():
+ flax_key_tuple_array = flax_key_tuple.split(".")
+
+ if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
+ flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
+ elif flax_key_tuple_array[-1] == "kernel":
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
+ flax_tensor = flax_tensor.T
+ elif flax_key_tuple_array[-1] == "scale":
+ flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
+
+ if "time_embedding" not in flax_key_tuple_array:
+ for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
+ flax_key_tuple_array[i] = (
+ flax_key_tuple_string.replace("_0", ".0")
+ .replace("_1", ".1")
+ .replace("_2", ".2")
+ .replace("_3", ".3")
+ .replace("_4", ".4")
+ .replace("_5", ".5")
+ .replace("_6", ".6")
+ .replace("_7", ".7")
+ .replace("_8", ".8")
+ .replace("_9", ".9")
+ )
+
+ flax_key = ".".join(flax_key_tuple_array)
+
+ if flax_key in pt_model_dict:
+ if flax_tensor.shape != pt_model_dict[flax_key].shape:
+ raise ValueError(
+ f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
+ f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
+ )
+ else:
+ # add weight to pytorch dict
+ flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
+ pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
+ # remove from missing keys
+ missing_keys.remove(flax_key)
+ else:
+ # weight is not expected by PyTorch model
+ unexpected_keys.append(flax_key)
+
+ pt_model.load_state_dict(pt_model_dict)
+
+ # re-transform missing_keys to list
+ missing_keys = list(missing_keys)
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ "Some weights of the Flax model were not used when initializing the PyTorch model"
+ f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
+ f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
+ " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
+ f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
+ " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
+ " FlaxBertForSequenceClassification model)."
+ )
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
+ f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
+ " use it for predictions and inference."
+ )
+
+ return pt_model
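
The `load_flax_weights_in_pytorch_model` shown above renames and reshapes Flax parameters before copying them into the PyTorch state dict: dense `kernel` becomes a transposed `weight`, 4-D conv `kernel` is permuted from HWIO to OIHW, and norm `scale` becomes `weight`. A standalone sketch of those rules, with toy arrays rather than the diffusers helper itself:

```py
# Standalone sketch of the key/shape rules applied above: dense "kernel" becomes
# a transposed "weight", 4-D conv "kernel" is permuted HWIO -> OIHW, and norm
# "scale" becomes "weight". Toy arrays, not the diffusers helper itself.
import numpy as np
import torch


def convert(flax_key: str, flax_tensor: np.ndarray):
    parts = flax_key.split(".")
    if parts[-1] == "kernel" and flax_tensor.ndim == 4:
        parts[-1] = "weight"
        flax_tensor = np.transpose(flax_tensor, (3, 2, 0, 1))
    elif parts[-1] == "kernel":
        parts[-1] = "weight"
        flax_tensor = flax_tensor.T
    elif parts[-1] == "scale":
        parts[-1] = "weight"
    return ".".join(parts), torch.from_numpy(np.ascontiguousarray(flax_tensor))


key, tensor = convert("conv_in.kernel", np.zeros((3, 3, 4, 8), dtype=np.float32))
print(key, tuple(tensor.shape))  # conv_in.weight (8, 4, 3, 3)
```
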
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 644c52f103fa..7639f75152a5 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -18,14 +18,13 @@
import itertools
import os
import re
-from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
import safetensors
import torch
from huggingface_hub import create_repo
-from torch import Tensor, nn
+from torch import Tensor, device, nn
from .. import __version__
from ..utils import (
@@ -62,7 +61,7 @@
from accelerate.utils.versions import is_torch_version
-def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
+def get_parameter_device(parameter: torch.nn.Module):
try:
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
return next(parameters_and_buffers).device
@@ -78,7 +77,7 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
return first_tuple[1].device
-def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
+def get_parameter_dtype(parameter: torch.nn.Module):
try:
params = tuple(parameter.parameters())
if len(params) > 0:
@@ -131,18 +130,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
)
-def load_model_dict_into_meta(
- model,
- state_dict: OrderedDict,
- device: Optional[Union[str, torch.device]] = None,
- dtype: Optional[Union[str, torch.dtype]] = None,
- model_name_or_path: Optional[str] = None,
-) -> List[str]:
+def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
- accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-
unexpected_keys = []
empty_state_dict = model.state_dict()
for param_name, param in state_dict.items():
@@ -156,6 +147,7 @@ def load_model_dict_into_meta(
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
+ accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
else:
@@ -163,7 +155,7 @@ def load_model_dict_into_meta(
return unexpected_keys
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+def _load_state_dict_into_model(model_to_load, state_dict):
# Convert old format to new format if needed from a PyTorch state_dict
# copy state_dict so _load_from_state_dict can modify it
state_dict = state_dict.copy()
@@ -171,7 +163,7 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
# so we need to apply the function recursively.
- def load(module: torch.nn.Module, prefix: str = ""):
+ def load(module: torch.nn.Module, prefix=""):
args = (state_dict, prefix, {}, True, [], [], error_msgs)
module._load_from_state_dict(*args)
@@ -193,7 +185,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
"""
-
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_supports_gradient_checkpointing = False
@@ -228,7 +219,7 @@ def is_gradient_checkpointing(self) -> bool:
"""
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
- def enable_gradient_checkpointing(self) -> None:
+ def enable_gradient_checkpointing(self):
"""
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
@@ -237,7 +228,7 @@ def enable_gradient_checkpointing(self) -> None:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
- def disable_gradient_checkpointing(self) -> None:
+ def disable_gradient_checkpointing(self):
"""
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
*checkpoint activations* in other frameworks).
@@ -262,7 +253,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
- def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
r"""
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
@@ -298,7 +289,7 @@ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Call
"""
self.set_use_memory_efficient_attention_xformers(True, attention_op)
- def disable_xformers_memory_efficient_attention(self) -> None:
+ def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
"""
@@ -455,7 +446,7 @@ def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
- save_function: Optional[Callable] = None,
+ save_function: Callable = None,
safe_serialization: bool = True,
variant: Optional[str] = None,
push_to_hub: bool = False,
@@ -918,10 +909,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
def _load_pretrained_model(
cls,
model,
- state_dict: OrderedDict,
+ state_dict,
resolved_archive_file,
- pretrained_model_name_or_path: Union[str, os.PathLike],
- ignore_mismatched_sizes: bool = False,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
):
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
@@ -1019,7 +1010,7 @@ def _find_mismatched_keys(
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
@property
- def device(self) -> torch.device:
+ def device(self) -> device:
"""
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
@@ -1071,7 +1062,7 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
- def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
+ def _convert_deprecated_attention_blocks(self, state_dict):
deprecated_attention_block_paths = []
def recursive_find_attn_block(name, module):
@@ -1115,7 +1106,7 @@ def recursive_find_attn_block(name, module):
if f"{path}.proj_attn.bias" in state_dict:
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
- def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
+ def _temp_convert_self_to_deprecated_attention_blocks(self):
deprecated_attention_block_modules = []
def recursive_find_attn_block(module):
@@ -1142,10 +1133,10 @@ def recursive_find_attn_block(module):
del module.to_v
del module.to_out
- def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
deprecated_attention_block_modules = []
- def recursive_find_attn_block(module) -> None:
+ def recursive_find_attn_block(module):
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
deprecated_attention_block_modules.append(module)
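
The `modeling_utils.py` hunks above move the `accepts_dtype` probe next to its call site; it is an ordinary `inspect.signature` check that forwards `dtype` only when the installed `set_module_tensor_to_device` accepts it. A sketch with a stand-in function (the real one comes from accelerate):

```py
# The accepts_dtype check above is a plain signature probe: dtype is passed to
# set_module_tensor_to_device only when the available function exposes that
# parameter. Stand-in function below; the real one comes from accelerate.
import inspect


def set_module_tensor_to_device(model, name, device, value=None, dtype=None):  # stand-in
    ...


accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
print(accepts_dtype)  # True for this stand-in
```
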
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 970d2be05b7a..80bf269fc4e3 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -22,9 +22,9 @@
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
+from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
from .lora import LoRACompatibleConv, LoRACompatibleLinear
-from .normalization import AdaGroupNorm
class Upsample1D(nn.Module):
@@ -164,12 +164,7 @@ def __init__(
else:
self.Conv2d_0 = conv
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- output_size: Optional[int] = None,
- scale: float = 1.0,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
@@ -261,7 +256,7 @@ def __init__(
else:
self.conv = conv
- def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+ def forward(self, hidden_states, scale: float = 1.0):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
@@ -285,7 +280,7 @@ class FirUpsample2D(nn.Module):
"""A 2D FIR upsampling layer with an optional convolution.
Parameters:
- channels (`int`, optional):
+ channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
@@ -297,7 +292,7 @@ class FirUpsample2D(nn.Module):
def __init__(
self,
- channels: Optional[int] = None,
+ channels: int = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
@@ -312,12 +307,12 @@ def __init__(
def _upsample_2d(
self,
- hidden_states: torch.FloatTensor,
- weight: Optional[torch.FloatTensor] = None,
+ hidden_states: torch.Tensor,
+ weight: Optional[torch.Tensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
- ) -> torch.FloatTensor:
+ ) -> torch.Tensor:
"""Fused `upsample_2d()` followed by `Conv2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -325,21 +320,17 @@ def _upsample_2d(
arbitrary order.
Args:
- hidden_states (`torch.FloatTensor`):
- Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- weight (`torch.FloatTensor`, *optional*):
- Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
- performed by `inChannels = x.shape[0] // numGroups`.
- kernel (`torch.FloatTensor`, *optional*):
- FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
- corresponds to nearest-neighbor upsampling.
- factor (`int`, *optional*): Integer upsampling factor (default: 2).
- gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
- output (`torch.FloatTensor`):
- Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
- datatype as `hidden_states`.
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+ datatype as `hidden_states`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -382,11 +373,7 @@ def _upsample_2d(
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
inverse_conv = F.conv_transpose2d(
- hidden_states,
- weight,
- stride=stride,
- output_padding=output_padding,
- padding=0,
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
)
output = upfirdn2d_native(
@@ -405,7 +392,7 @@ def _upsample_2d(
return output
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.use_conv:
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -431,7 +418,7 @@ class FirDownsample2D(nn.Module):
def __init__(
self,
- channels: Optional[int] = None,
+ channels: int = None,
out_channels: Optional[int] = None,
use_conv: bool = False,
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
@@ -446,35 +433,30 @@ def __init__(
def _downsample_2d(
self,
- hidden_states: torch.FloatTensor,
- weight: Optional[torch.FloatTensor] = None,
+ hidden_states: torch.Tensor,
+ weight: Optional[torch.Tensor] = None,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
- ) -> torch.FloatTensor:
+ ) -> torch.Tensor:
"""Fused `Conv2d()` followed by `downsample_2d()`.
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
arbitrary order.
Args:
- hidden_states (`torch.FloatTensor`):
- Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- weight (`torch.FloatTensor`, *optional*):
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight:
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
performed by `inChannels = x.shape[0] // numGroups`.
- kernel (`torch.FloatTensor`, *optional*):
- FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
- corresponds to average pooling.
- factor (`int`, *optional*, default to `2`):
- Integer downsampling factor.
- gain (`float`, *optional*, default to `1.0`):
- Scaling factor for signal magnitude.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
+ factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
- output (`torch.FloatTensor`):
- Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
- datatype as `x`.
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
+ same datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
@@ -510,7 +492,7 @@ def _downsample_2d(
return output
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.use_conv:
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -537,14 +519,7 @@ def __init__(self, pad_mode: str = "reflect"):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
- weight = inputs.new_zeros(
- [
- inputs.shape[1],
- inputs.shape[1],
- self.kernel.shape[0],
- self.kernel.shape[1],
- ]
- )
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -567,14 +542,7 @@ def __init__(self, pad_mode: str = "reflect"):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
- weight = inputs.new_zeros(
- [
- inputs.shape[1],
- inputs.shape[1],
- self.kernel.shape[0],
- self.kernel.shape[1],
- ]
- )
+ weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -711,20 +679,10 @@ def __init__(
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
- in_channels,
- conv_2d_out_channels,
- kernel_size=1,
- stride=1,
- padding=0,
- bias=conv_shortcut_bias,
+ in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)
- def forward(
- self,
- input_tensor: torch.FloatTensor,
- temb: torch.FloatTensor,
- scale: float = 1.0,
- ) -> torch.FloatTensor:
+ def forward(self, input_tensor, temb, scale: float = 1.0):
hidden_states = input_tensor
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
@@ -820,22 +778,16 @@ class Conv1dBlock(nn.Module):
out_channels (`int`): Number of output channels.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
n_groups (`int`, default `8`): Number of groups to separate the channels into.
- activation (`str`, defaults to `mish`): Name of the activation function.
"""
def __init__(
- self,
- inp_channels: int,
- out_channels: int,
- kernel_size: Union[int, Tuple[int, int]],
- n_groups: int = 8,
- activation: str = "mish",
+ self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
):
super().__init__()
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.group_norm = nn.GroupNorm(n_groups, out_channels)
- self.mish = get_activation(activation)
+ self.mish = nn.Mish()
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
intermediate_repr = self.conv1d(inputs)
@@ -856,22 +808,16 @@ class ResidualTemporalBlock1D(nn.Module):
out_channels (`int`): Number of output channels.
embed_dim (`int`): Embedding dimension.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
- activation (`str`, defaults `mish`): It is possible to choose the right activation function.
"""
def __init__(
- self,
- inp_channels: int,
- out_channels: int,
- embed_dim: int,
- kernel_size: Union[int, Tuple[int, int]] = 5,
- activation: str = "mish",
+ self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
):
super().__init__()
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
- self.time_emb_act = get_activation(activation)
+ self.time_emb_act = nn.Mish()
self.time_emb = nn.Linear(embed_dim, out_channels)
self.residual_conv = (
@@ -895,11 +841,8 @@ def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
def upsample_2d(
- hidden_states: torch.FloatTensor,
- kernel: Optional[torch.FloatTensor] = None,
- factor: int = 2,
- gain: float = 1,
-) -> torch.FloatTensor:
+ hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+) -> torch.Tensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
@@ -907,19 +850,14 @@ def upsample_2d(
a: multiple of the upsampling factor.
Args:
- hidden_states (`torch.FloatTensor`):
- Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- kernel (`torch.FloatTensor`, *optional*):
- FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
- corresponds to nearest-neighbor upsampling.
- factor (`int`, *optional*, default to `2`):
- Integer upsampling factor.
- gain (`float`, *optional*, default to `1.0`):
- Scaling factor for signal magnitude (default: 1.0).
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
- output (`torch.FloatTensor`):
- Tensor of the shape `[N, C, H * factor, W * factor]`
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if kernel is None:
@@ -942,11 +880,8 @@ def upsample_2d(
def downsample_2d(
- hidden_states: torch.FloatTensor,
- kernel: Optional[torch.FloatTensor] = None,
- factor: int = 2,
- gain: float = 1,
-) -> torch.FloatTensor:
+ hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
+) -> torch.Tensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -954,19 +889,14 @@ def downsample_2d(
shape is a multiple of the downsampling factor.
Args:
- hidden_states (`torch.FloatTensor`)
- Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
- kernel (`torch.FloatTensor`, *optional*):
- FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
- corresponds to average pooling.
- factor (`int`, *optional*, default to `2`):
- Integer downsampling factor.
- gain (`float`, *optional*, default to `1.0`):
- Scaling factor for signal magnitude.
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
- output (`torch.FloatTensor`):
- Tensor of the shape `[N, C, H // factor, W // factor]`
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
@@ -981,20 +911,13 @@ def downsample_2d(
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
- hidden_states,
- kernel.to(device=hidden_states.device),
- down=factor,
- pad=((pad_value + 1) // 2, pad_value // 2),
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
)
return output
def upfirdn2d_native(
- tensor: torch.Tensor,
- kernel: torch.Tensor,
- up: int = 1,
- down: int = 1,
- pad: Tuple[int, int] = (0, 0),
+ tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
@@ -1050,13 +973,7 @@ class TemporalConvLayer(nn.Module):
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
"""
- def __init__(
- self,
- in_dim: int,
- out_dim: Optional[int] = None,
- dropout: float = 0.0,
- norm_num_groups: int = 32,
- ):
+ def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
@@ -1064,24 +981,22 @@ def __init__(
# conv layers
self.conv1 = nn.Sequential(
- nn.GroupNorm(norm_num_groups, in_dim),
- nn.SiLU(),
- nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
+ nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
)
self.conv2 = nn.Sequential(
- nn.GroupNorm(norm_num_groups, out_dim),
+ nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv3 = nn.Sequential(
- nn.GroupNorm(norm_num_groups, out_dim),
+ nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv4 = nn.Sequential(
- nn.GroupNorm(norm_num_groups, out_dim),
+ nn.GroupNorm(32, out_dim),
nn.SiLU(),
nn.Dropout(dropout),
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
@@ -1108,261 +1023,3 @@ def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Ten
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
)
return hidden_states
-
-
-class TemporalResnetBlock(nn.Module):
- r"""
- A Resnet block.
-
- Parameters:
- in_channels (`int`): The number of channels in the input.
- out_channels (`int`, *optional*, default to be `None`):
- The number of output channels for the first conv2d layer. If None, same as `in_channels`.
- temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
- eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: Optional[int] = None,
- temb_channels: int = 512,
- eps: float = 1e-6,
- ):
- super().__init__()
- self.in_channels = in_channels
- out_channels = in_channels if out_channels is None else out_channels
- self.out_channels = out_channels
-
- kernel_size = (3, 1, 1)
- padding = [k // 2 for k in kernel_size]
-
- self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
- self.conv1 = nn.Conv3d(
- in_channels,
- out_channels,
- kernel_size=kernel_size,
- stride=1,
- padding=padding,
- )
-
- if temb_channels is not None:
- self.time_emb_proj = nn.Linear(temb_channels, out_channels)
- else:
- self.time_emb_proj = None
-
- self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
-
- self.dropout = torch.nn.Dropout(0.0)
- self.conv2 = nn.Conv3d(
- out_channels,
- out_channels,
- kernel_size=kernel_size,
- stride=1,
- padding=padding,
- )
-
- self.nonlinearity = get_activation("silu")
-
- self.use_in_shortcut = self.in_channels != out_channels
-
- self.conv_shortcut = None
- if self.use_in_shortcut:
- self.conv_shortcut = nn.Conv3d(
- in_channels,
- out_channels,
- kernel_size=1,
- stride=1,
- padding=0,
- )
-
- def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
- hidden_states = input_tensor
-
- hidden_states = self.norm1(hidden_states)
- hidden_states = self.nonlinearity(hidden_states)
- hidden_states = self.conv1(hidden_states)
-
- if self.time_emb_proj is not None:
- temb = self.nonlinearity(temb)
- temb = self.time_emb_proj(temb)[:, :, :, None, None]
- temb = temb.permute(0, 2, 1, 3, 4)
- hidden_states = hidden_states + temb
-
- hidden_states = self.norm2(hidden_states)
- hidden_states = self.nonlinearity(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.conv2(hidden_states)
-
- if self.conv_shortcut is not None:
- input_tensor = self.conv_shortcut(input_tensor)
-
- output_tensor = input_tensor + hidden_states
-
- return output_tensor
-
-
-# VideoResBlock
-class SpatioTemporalResBlock(nn.Module):
- r"""
- A SpatioTemporal Resnet block.
-
- Parameters:
- in_channels (`int`): The number of channels in the input.
- out_channels (`int`, *optional*, default to be `None`):
- The number of output channels for the first conv2d layer. If None, same as `in_channels`.
- temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
- eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resenet.
- temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
- merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
- merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
- The merge strategy to use for the temporal mixing.
- switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
- If `True`, switch the spatial and temporal mixing.
- """
-
- def __init__(
- self,
- in_channels: int,
- out_channels: Optional[int] = None,
- temb_channels: int = 512,
- eps: float = 1e-6,
- temporal_eps: Optional[float] = None,
- merge_factor: float = 0.5,
- merge_strategy="learned_with_images",
- switch_spatial_to_temporal_mix: bool = False,
- ):
- super().__init__()
-
- self.spatial_res_block = ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=eps,
- )
-
- self.temporal_res_block = TemporalResnetBlock(
- in_channels=out_channels if out_channels is not None else in_channels,
- out_channels=out_channels if out_channels is not None else in_channels,
- temb_channels=temb_channels,
- eps=temporal_eps if temporal_eps is not None else eps,
- )
-
- self.time_mixer = AlphaBlender(
- alpha=merge_factor,
- merge_strategy=merge_strategy,
- switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
- )
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- image_only_indicator: Optional[torch.Tensor] = None,
- ):
- num_frames = image_only_indicator.shape[-1]
- hidden_states = self.spatial_res_block(hidden_states, temb)
-
- batch_frames, channels, height, width = hidden_states.shape
- batch_size = batch_frames // num_frames
-
- hidden_states_mix = (
- hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
- )
- hidden_states = (
- hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
- )
-
- if temb is not None:
- temb = temb.reshape(batch_size, num_frames, -1)
-
- hidden_states = self.temporal_res_block(hidden_states, temb)
- hidden_states = self.time_mixer(
- x_spatial=hidden_states_mix,
- x_temporal=hidden_states,
- image_only_indicator=image_only_indicator,
- )
-
- hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
- return hidden_states
-
-
-class AlphaBlender(nn.Module):
- r"""
- A module to blend spatial and temporal features.
-
- Parameters:
- alpha (`float`): The initial value of the blending factor.
- merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
- The merge strategy to use for the temporal mixing.
- switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
- If `True`, switch the spatial and temporal mixing.
- """
-
- strategies = ["learned", "fixed", "learned_with_images"]
-
- def __init__(
- self,
- alpha: float,
- merge_strategy: str = "learned_with_images",
- switch_spatial_to_temporal_mix: bool = False,
- ):
- super().__init__()
- self.merge_strategy = merge_strategy
- self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE
-
- if merge_strategy not in self.strategies:
- raise ValueError(f"merge_strategy needs to be in {self.strategies}")
-
- if self.merge_strategy == "fixed":
- self.register_buffer("mix_factor", torch.Tensor([alpha]))
- elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
- self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
- else:
- raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
-
- def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
- if self.merge_strategy == "fixed":
- alpha = self.mix_factor
-
- elif self.merge_strategy == "learned":
- alpha = torch.sigmoid(self.mix_factor)
-
- elif self.merge_strategy == "learned_with_images":
- if image_only_indicator is None:
- raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
-
- alpha = torch.where(
- image_only_indicator.bool(),
- torch.ones(1, 1, device=image_only_indicator.device),
- torch.sigmoid(self.mix_factor)[..., None],
- )
-
- # (batch, channel, frames, height, width)
- if ndims == 5:
- alpha = alpha[:, None, :, None, None]
- # (batch*frames, height*width, channels)
- elif ndims == 3:
- alpha = alpha.reshape(-1)[:, None, None]
- else:
- raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
-
- else:
- raise NotImplementedError
-
- return alpha
-
- def forward(
- self,
- x_spatial: torch.Tensor,
- x_temporal: torch.Tensor,
- image_only_indicator: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
- alpha = alpha.to(x_spatial.dtype)
-
- if self.switch_spatial_to_temporal_mix:
- alpha = 1.0 - alpha
-
- x = alpha * x_spatial + (1.0 - alpha) * x_temporal
- return x
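
The removed `AlphaBlender` above interpolates the spatial and temporal branches with a scalar mixing factor (fixed, or a sigmoid of a learned parameter). A toy sketch of the "learned" strategy only, with illustrative tensor shapes:

```py
# Toy sketch of the blending rule in the removed AlphaBlender: a scalar mixing
# factor interpolates between the spatial and temporal branches.
# "learned" strategy only, illustrative shapes.
import torch

mix_factor = torch.nn.Parameter(torch.tensor([0.5]))
x_spatial = torch.randn(2, 4, 8, 16, 16)    # (batch, channels, frames, height, width)
x_temporal = torch.randn(2, 4, 8, 16, 16)

alpha = torch.sigmoid(mix_factor)           # "learned" strategy
blended = alpha * x_spatial + (1.0 - alpha) * x_temporal
print(blended.shape)
```
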
diff --git a/src/diffusers/models/t5_film_transformer.py b/src/diffusers/models/t5_film_transformer.py
index 26ff3f6b8127..1c41e656a9db 100644
--- a/src/diffusers/models/t5_film_transformer.py
+++ b/src/diffusers/models/t5_film_transformer.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
-from typing import Optional, Tuple
import torch
from torch import nn
@@ -24,28 +23,6 @@
class T5FilmDecoder(ModelMixin, ConfigMixin):
- r"""
- T5 style decoder with FiLM conditioning.
-
- Args:
- input_dims (`int`, *optional*, defaults to `128`):
- The number of input dimensions.
- targets_length (`int`, *optional*, defaults to `256`):
- The length of the targets.
- d_model (`int`, *optional*, defaults to `768`):
- Size of the input hidden states.
- num_layers (`int`, *optional*, defaults to `12`):
- The number of `DecoderLayer`'s to use.
- num_heads (`int`, *optional*, defaults to `12`):
- The number of attention heads to use.
- d_kv (`int`, *optional*, defaults to `64`):
- Size of the key-value projection vectors.
- d_ff (`int`, *optional*, defaults to `2048`):
- The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
- dropout_rate (`float`, *optional*, defaults to `0.1`):
- Dropout probability.
- """
-
@register_to_config
def __init__(
self,
@@ -86,7 +63,7 @@ def __init__(
self.post_dropout = nn.Dropout(p=dropout_rate)
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
- def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
+ def encoder_decoder_mask(self, query_input, key_input):
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.unsqueeze(-3)
@@ -148,27 +125,7 @@ def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time)
class DecoderLayer(nn.Module):
- r"""
- T5 decoder layer.
-
- Args:
- d_model (`int`):
- Size of the input hidden states.
- d_kv (`int`):
- Size of the key-value projection vectors.
- num_heads (`int`):
- Number of attention heads.
- d_ff (`int`):
- Size of the intermediate feed-forward layer.
- dropout_rate (`float`):
- Dropout probability.
- layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
- A small value used for numerical stability to avoid dividing by zero.
- """
-
- def __init__(
- self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
- ):
+ def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
super().__init__()
self.layer = nn.ModuleList()
@@ -195,13 +152,13 @@ def __init__(
def forward(
self,
- hidden_states: torch.FloatTensor,
- conditioning_emb: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
+ hidden_states,
+ conditioning_emb=None,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
encoder_decoder_position_bias=None,
- ) -> Tuple[torch.FloatTensor]:
+ ):
hidden_states = self.layer[0](
hidden_states,
conditioning_emb=conditioning_emb,
@@ -226,21 +183,7 @@ def forward(
class T5LayerSelfAttentionCond(nn.Module):
- r"""
- T5 style self-attention layer with conditioning.
-
- Args:
- d_model (`int`):
- Size of the input hidden states.
- d_kv (`int`):
- Size of the key-value projection vectors.
- num_heads (`int`):
- Number of attention heads.
- dropout_rate (`float`):
- Dropout probability.
- """
-
- def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
+ def __init__(self, d_model, d_kv, num_heads, dropout_rate):
super().__init__()
self.layer_norm = T5LayerNorm(d_model)
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
@@ -249,10 +192,10 @@ def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float)
def forward(
self,
- hidden_states: torch.FloatTensor,
- conditioning_emb: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ hidden_states,
+ conditioning_emb=None,
+ attention_mask=None,
+ ):
# pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states)
@@ -268,23 +211,7 @@ def forward(
class T5LayerCrossAttention(nn.Module):
- r"""
- T5 style cross-attention layer.
-
- Args:
- d_model (`int`):
- Size of the input hidden states.
- d_kv (`int`):
- Size of the key-value projection vectors.
- num_heads (`int`):
- Number of attention heads.
- dropout_rate (`float`):
- Dropout probability.
- layer_norm_epsilon (`float`):
- A small value used for numerical stability to avoid dividing by zero.
- """
-
- def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
+ def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
super().__init__()
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
@@ -292,10 +219,10 @@ def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float,
def forward(
self,
- hidden_states: torch.FloatTensor,
- key_value_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ hidden_states,
+ key_value_states=None,
+ attention_mask=None,
+ ):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
@@ -307,30 +234,14 @@ def forward(
class T5LayerFFCond(nn.Module):
- r"""
- T5 style feed-forward conditional layer.
-
- Args:
- d_model (`int`):
- Size of the input hidden states.
- d_ff (`int`):
- Size of the intermediate feed-forward layer.
- dropout_rate (`float`):
- Dropout probability.
- layer_norm_epsilon (`float`):
- A small value used for numerical stability to avoid dividing by zero.
- """
-
- def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
+ def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)
- def forward(
- self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, conditioning_emb=None):
forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb)
@@ -341,19 +252,7 @@ def forward(
class T5DenseGatedActDense(nn.Module):
- r"""
- T5 style feed-forward layer with gated activations and dropout.
-
- Args:
- d_model (`int`):
- Size of the input hidden states.
- d_ff (`int`):
- Size of the intermediate feed-forward layer.
- dropout_rate (`float`):
- Dropout probability.
- """
-
- def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
+ def __init__(self, d_model, d_ff, dropout_rate):
super().__init__()
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
@@ -361,7 +260,7 @@ def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation()
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, hidden_states):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
@@ -372,17 +271,7 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
class T5LayerNorm(nn.Module):
- r"""
- T5 style layer normalization module.
-
- Args:
- hidden_size (`int`):
- Size of the input hidden states.
- eps (`float`, `optional`, defaults to `1e-6`):
- A small value used for numerical stability to avoid dividing by zero.
- """
-
- def __init__(self, hidden_size: int, eps: float = 1e-6):
+ def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
@@ -390,7 +279,7 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, hidden_states):
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
@@ -418,20 +307,14 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
class T5FiLMLayer(nn.Module):
"""
- T5 style FiLM Layer.
-
- Args:
- in_features (`int`):
- Number of input features.
- out_features (`int`):
- Number of output features.
+ FiLM Layer
"""
- def __init__(self, in_features: int, out_features: int):
+ def __init__(self, in_features, out_features):
super().__init__()
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
- def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, x, conditioning_emb):
emb = self.scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, -1)
x = x * (1 + scale) + shift
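
The `T5FiLMLayer` retained above is the whole FiLM idea in three lines: project the conditioning embedding to a scale and a shift, then apply `x * (1 + scale) + shift`. A minimal, self-contained sketch of that pattern, using made-up sizes rather than the module from this file:

```py
import torch
from torch import nn


class FiLM(nn.Module):
    # One linear projection produces both the scale and the shift.
    def __init__(self, cond_features: int, out_features: int):
        super().__init__()
        self.scale_bias = nn.Linear(cond_features, out_features * 2, bias=False)

    def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
        emb = self.scale_bias(conditioning_emb)      # (..., 2 * out_features)
        scale, shift = torch.chunk(emb, 2, dim=-1)   # split into two halves
        return x * (1 + scale) + shift               # identity when emb is all zeros


# Hypothetical sizes: d_model = 8, conditioning vector of size 4 * d_model.
film = FiLM(cond_features=32, out_features=8)
x = torch.randn(2, 16, 8)      # (batch, sequence, d_model)
cond = torch.randn(2, 1, 32)   # broadcast over the sequence dimension
print(film(x, cond).shape)     # torch.Size([2, 16, 8])
```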
diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 3aecc43f0f5b..0f00932f3014 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -20,12 +20,11 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
+from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
from .attention import BasicTransformerBlock
-from .embeddings import CaptionProjection, PatchEmbed
+from .embeddings import PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
-from .normalization import AdaLayerNormSingle
@dataclass
@@ -70,8 +69,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
- _supports_gradient_checkpointing = True
-
@register_to_config
def __init__(
self,
@@ -95,9 +92,7 @@ def __init__(
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
- norm_eps: float = 1e-5,
attention_type: str = "default",
- caption_channels: int = None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
@@ -169,15 +164,12 @@ def __init__(
self.width = sample_size
self.patch_size = patch_size
- interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
- interpolation_scale = max(interpolation_scale, 1)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
- interpolation_scale=interpolation_scale,
)
# 3. Define transformers blocks
@@ -197,7 +189,6 @@ def __init__(
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
attention_type=attention_type,
)
for d in range(num_layers)
@@ -215,40 +206,18 @@ def __init__(
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
- elif self.is_input_patches and norm_type != "ada_norm_single":
+ elif self.is_input_patches:
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
- elif self.is_input_patches and norm_type == "ada_norm_single":
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
- self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
- # 5. PixArt-Alpha blocks.
- self.adaln_single = None
- self.use_additional_conditions = False
- if norm_type == "ada_norm_single":
- self.use_additional_conditions = self.config.sample_size == 128
- # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
- # additional conditions until we find better name
- self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
-
- self.caption_projection = None
- if caption_channels is not None:
- self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
self.gradient_checkpointing = False
- def _set_gradient_checkpointing(self, module, value=False):
- if hasattr(module, "gradient_checkpointing"):
- module.gradient_checkpointing = value
-
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -345,40 +314,13 @@ def forward(
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
- height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
- if self.adaln_single is not None:
- if self.use_additional_conditions and added_cond_kwargs is None:
- raise ValueError(
- "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
- )
- batch_size = hidden_states.shape[0]
- timestep, embedded_timestep = self.adaln_single(
- timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
- )
-
# 2. Blocks
- if self.caption_projection is not None:
- batch_size = hidden_states.shape[0]
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
+ block,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -386,7 +328,7 @@ def custom_forward(*inputs):
timestep,
cross_attention_kwargs,
class_labels,
- **ckpt_kwargs,
+ use_reentrant=False,
)
else:
hidden_states = block(
@@ -425,26 +367,17 @@ def custom_forward(*inputs):
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
-
- if self.is_input_patches:
- if self.config.norm_type != "ada_norm_single":
- conditioning = self.transformer_blocks[0].norm1.emb(
- timestep, class_labels, hidden_dtype=hidden_states.dtype
- )
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
- hidden_states = self.proj_out_2(hidden_states)
- elif self.config.norm_type == "ada_norm_single":
- shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
- hidden_states = self.norm_out(hidden_states)
- # Modulation
- hidden_states = hidden_states * (1 + scale) + shift
- hidden_states = self.proj_out(hidden_states)
- hidden_states = hidden_states.squeeze(1)
+ elif self.is_input_patches:
+ # TODO: cleanup!
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
# unpatchify
- if self.adaln_single is None:
- height = width = int(hidden_states.shape[1] ** 0.5)
+ height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
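
The `transformer_2d.py` hunks above boil the gradient-checkpointing path down to one question: how the block is handed to `torch.utils.checkpoint.checkpoint`. It can be passed directly as a callable with `use_reentrant=False`, or wrapped in a small closure when keyword-only arguments (such as `return_dict`) have to be forwarded. A sketch of both call styles with a stand-in block, not the real transformer block:

```py
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Style 1: pass the module itself; positional arguments only.
out_direct = checkpoint(block, x, use_reentrant=False)

# Style 2: wrap the module so extra keyword arguments could be injected
# at call time (the role of the removed create_custom_forward helper).
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

out_wrapped = checkpoint(create_custom_forward(block), x, use_reentrant=False)

# In both cases the intermediate activations are recomputed during backward.
(out_direct.sum() + out_wrapped.sum()).backward()
```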
diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py
index 26e899a9b908..d002cb3315fa 100644
--- a/src/diffusers/models/transformer_temporal.py
+++ b/src/diffusers/models/transformer_temporal.py
@@ -12,17 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Optional
import torch
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
-from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
-from .embeddings import TimestepEmbedding, Timesteps
+from .attention import BasicTransformerBlock
from .modeling_utils import ModelMixin
-from .resnet import AlphaBlender
@dataclass
@@ -50,21 +48,13 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- attention_bias (`bool`, *optional*):
- Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
- activation_fn (`str`, *optional*, defaults to `"geglu"`):
- Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
- activation functions.
- norm_elementwise_affine (`bool`, *optional*):
- Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
- positional_embeddings: (`str`, *optional*):
- The type of positional embeddings to apply to the sequence input before passing use.
- num_positional_embeddings: (`int`, *optional*):
- The maximum length of the sequence over which to apply positional embeddings.
"""
@register_to_config
@@ -83,8 +73,6 @@ def __init__(
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
double_self_attention: bool = True,
- positional_embeddings: Optional[str] = None,
- num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -109,8 +97,6 @@ def __init__(
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
- positional_embeddings=positional_embeddings,
- num_positional_embeddings=num_positional_embeddings,
)
for d in range(num_layers)
]
@@ -120,14 +106,14 @@ def __init__(
def forward(
self,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: Optional[torch.LongTensor] = None,
- timestep: Optional[torch.LongTensor] = None,
- class_labels: torch.LongTensor = None,
- num_frames: int = 1,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ hidden_states,
+ encoder_hidden_states=None,
+ timestep=None,
+ class_labels=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
return_dict: bool = True,
- ) -> TransformerTemporalModelOutput:
+ ):
"""
The [`TransformerTemporal`] forward method.
@@ -137,7 +123,7 @@ def forward(
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
- timestep ( `torch.LongTensor`, *optional*):
+ timestep ( `torch.long`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
@@ -185,7 +171,7 @@ def forward(
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states[None, None, :]
- .reshape(batch_size, height, width, num_frames, channel)
+ .reshape(batch_size, height, width, channel, num_frames)
.permute(0, 3, 4, 1, 2)
.contiguous()
)
@@ -197,183 +183,3 @@ def forward(
return (output,)
return TransformerTemporalModelOutput(sample=output)
-
-
-class TransformerSpatioTemporalModel(nn.Module):
- """
- A Transformer model for video-like data.
-
- Parameters:
- num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
- attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
- in_channels (`int`, *optional*):
- The number of channels in the input and output (specify if the input is **continuous**).
- out_channels (`int`, *optional*):
- The number of channels in the output (specify if the input is **continuous**).
- num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- """
-
- def __init__(
- self,
- num_attention_heads: int = 16,
- attention_head_dim: int = 88,
- in_channels: int = 320,
- out_channels: Optional[int] = None,
- num_layers: int = 1,
- cross_attention_dim: Optional[int] = None,
- ):
- super().__init__()
- self.num_attention_heads = num_attention_heads
- self.attention_head_dim = attention_head_dim
-
- inner_dim = num_attention_heads * attention_head_dim
- self.inner_dim = inner_dim
-
- # 2. Define input layers
- self.in_channels = in_channels
- self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
- self.proj_in = nn.Linear(in_channels, inner_dim)
-
- # 3. Define transformers blocks
- self.transformer_blocks = nn.ModuleList(
- [
- BasicTransformerBlock(
- inner_dim,
- num_attention_heads,
- attention_head_dim,
- cross_attention_dim=cross_attention_dim,
- )
- for d in range(num_layers)
- ]
- )
-
- time_mix_inner_dim = inner_dim
- self.temporal_transformer_blocks = nn.ModuleList(
- [
- TemporalBasicTransformerBlock(
- inner_dim,
- time_mix_inner_dim,
- num_attention_heads,
- attention_head_dim,
- cross_attention_dim=cross_attention_dim,
- )
- for _ in range(num_layers)
- ]
- )
-
- time_embed_dim = in_channels * 4
- self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
- self.time_proj = Timesteps(in_channels, True, 0)
- self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
-
- # 4. Define output layers
- self.out_channels = in_channels if out_channels is None else out_channels
- # TODO: should use out_channels for continuous projections
- self.proj_out = nn.Linear(inner_dim, in_channels)
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- image_only_indicator: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ):
- """
- Args:
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
- Input hidden_states.
- num_frames (`int`):
- The number of frames to be processed per batch. This is used to reshape the hidden states.
- encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
- Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
- self-attention.
- image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
- A tensor indicating whether the input contains only images. 1 indicates that the input contains only
- images, 0 indicates that the input contains video frames.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
- tuple.
-
- Returns:
- [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
- If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
- returned, otherwise a `tuple` where the first element is the sample tensor.
- """
- # 1. Input
- batch_frames, _, height, width = hidden_states.shape
- num_frames = image_only_indicator.shape[-1]
- batch_size = batch_frames // num_frames
-
- time_context = encoder_hidden_states
- time_context_first_timestep = time_context[None, :].reshape(
- batch_size, num_frames, -1, time_context.shape[-1]
- )[:, 0]
- time_context = time_context_first_timestep[None, :].broadcast_to(
- height * width, batch_size, 1, time_context.shape[-1]
- )
- time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
-
- residual = hidden_states
-
- hidden_states = self.norm(hidden_states)
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
- hidden_states = self.proj_in(hidden_states)
-
- num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
- num_frames_emb = num_frames_emb.repeat(batch_size, 1)
- num_frames_emb = num_frames_emb.reshape(-1)
- t_emb = self.time_proj(num_frames_emb)
-
- # `Timesteps` does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=hidden_states.dtype)
-
- emb = self.time_pos_embed(t_emb)
- emb = emb[:, None, :]
-
- # 2. Blocks
- for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
- if self.training and self.gradient_checkpointing:
- hidden_states = torch.utils.checkpoint.checkpoint(
- block,
- hidden_states,
- None,
- encoder_hidden_states,
- None,
- use_reentrant=False,
- )
- else:
- hidden_states = block(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- )
-
- hidden_states_mix = hidden_states
- hidden_states_mix = hidden_states_mix + emb
-
- hidden_states_mix = temporal_block(
- hidden_states_mix,
- num_frames=num_frames,
- encoder_hidden_states=time_context,
- )
- hidden_states = self.time_mixer(
- x_spatial=hidden_states,
- x_temporal=hidden_states_mix,
- image_only_indicator=image_only_indicator,
- )
-
- # 3. Output
- hidden_states = self.proj_out(hidden_states)
- hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
- output = hidden_states + residual
-
- if not return_dict:
- return (output,)
-
- return TransformerTemporalModelOutput(sample=output)
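
The only behavioral change in `TransformerTemporalModel.forward` above is the axis order of the final reshape before the fixed `.permute(0, 3, 4, 1, 2)`. A quick shape check with arbitrary small sizes makes the difference between the `+` and `-` orderings explicit:

```py
import torch

batch_size, height, width, num_frames, channel = 2, 4, 4, 3, 8
hidden_states = torch.randn(batch_size * height * width * num_frames, channel)

# "+" side of the hunk: (..., channel, num_frames) -> (batch, channel, frames, height, width)
plus = hidden_states.reshape(batch_size, height, width, channel, num_frames).permute(0, 3, 4, 1, 2)
print(plus.shape)   # torch.Size([2, 8, 3, 4, 4])

# "-" side of the hunk: (..., num_frames, channel) -> (batch, frames, channel, height, width)
minus = hidden_states.reshape(batch_size, height, width, num_frames, channel).permute(0, 3, 4, 1, 2)
print(minus.shape)  # torch.Size([2, 3, 8, 4, 4])
```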
diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py
index 74a2f1681ead..84ae48e0f8c4 100644
--- a/src/diffusers/models/unet_1d_blocks.py
+++ b/src/diffusers/models/unet_1d_blocks.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
-from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -25,17 +24,17 @@
class DownResnetBlock1D(nn.Module):
def __init__(
self,
- in_channels: int,
- out_channels: Optional[int] = None,
- num_layers: int = 1,
- conv_shortcut: bool = False,
- temb_channels: int = 32,
- groups: int = 32,
- groups_out: Optional[int] = None,
- non_linearity: Optional[str] = None,
- time_embedding_norm: str = "default",
- output_scale_factor: float = 1.0,
- add_downsample: bool = True,
+ in_channels,
+ out_channels=None,
+ num_layers=1,
+ conv_shortcut=False,
+ temb_channels=32,
+ groups=32,
+ groups_out=None,
+ non_linearity=None,
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ add_downsample=True,
):
super().__init__()
self.in_channels = in_channels
@@ -66,7 +65,7 @@ def __init__(
if add_downsample:
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None):
output_states = ()
hidden_states = self.resnets[0](hidden_states, temb)
@@ -87,16 +86,16 @@ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTe
class UpResnetBlock1D(nn.Module):
def __init__(
self,
- in_channels: int,
- out_channels: Optional[int] = None,
- num_layers: int = 1,
- temb_channels: int = 32,
- groups: int = 32,
- groups_out: Optional[int] = None,
- non_linearity: Optional[str] = None,
- time_embedding_norm: str = "default",
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
+ in_channels,
+ out_channels=None,
+ num_layers=1,
+ temb_channels=32,
+ groups=32,
+ groups_out=None,
+ non_linearity=None,
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ add_upsample=True,
):
super().__init__()
self.in_channels = in_channels
@@ -126,12 +125,7 @@ def __init__(
if add_upsample:
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Optional[Tuple[torch.FloatTensor, ...]] = None,
- temb: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
if res_hidden_states_tuple is not None:
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
@@ -150,7 +144,7 @@ def forward(
class ValueFunctionMidBlock1D(nn.Module):
- def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
+ def __init__(self, in_channels, out_channels, embed_dim):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -161,7 +155,7 @@ def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
- def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ def forward(self, x, temb=None):
x = self.res1(x, temb)
x = self.down1(x)
x = self.res2(x, temb)
@@ -172,13 +166,13 @@ def forward(self, x: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
class MidResTemporalBlock1D(nn.Module):
def __init__(
self,
- in_channels: int,
- out_channels: int,
- embed_dim: int,
+ in_channels,
+ out_channels,
+ embed_dim,
num_layers: int = 1,
add_downsample: bool = False,
add_upsample: bool = False,
- non_linearity: Optional[str] = None,
+ non_linearity=None,
):
super().__init__()
self.in_channels = in_channels
@@ -209,7 +203,7 @@ def __init__(
if self.upsample and self.downsample:
raise ValueError("Block cannot downsample and upsample")
- def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb):
hidden_states = self.resnets[0](hidden_states, temb)
for resnet in self.resnets[1:]:
hidden_states = resnet(hidden_states, temb)
@@ -223,14 +217,14 @@ def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) ->
class OutConv1DBlock(nn.Module):
- def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
+ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
super().__init__()
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
self.final_conv1d_act = get_activation(act_fn)
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None):
hidden_states = self.final_conv1d_1(hidden_states)
hidden_states = rearrange_dims(hidden_states)
hidden_states = self.final_conv1d_gn(hidden_states)
@@ -241,7 +235,7 @@ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTe
class OutValueFunctionBlock(nn.Module):
- def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
+ def __init__(self, fc_dim, embed_dim, act_fn="mish"):
super().__init__()
self.final_block = nn.ModuleList(
[
@@ -251,7 +245,7 @@ def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
]
)
- def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb):
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
hidden_states = torch.cat((hidden_states, temb), dim=-1)
for layer in self.final_block:
@@ -281,14 +275,14 @@ def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor) ->
class Downsample1d(nn.Module):
- def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
+ def __init__(self, kernel="linear", pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, hidden_states):
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -298,14 +292,14 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
class Upsample1d(nn.Module):
- def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
+ def __init__(self, kernel="linear", pad_mode="reflect"):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer("kernel", kernel_1d)
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None):
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
@@ -315,7 +309,7 @@ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTe
class SelfAttention1d(nn.Module):
- def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
+ def __init__(self, in_channels, n_head=1, dropout_rate=0.0):
super().__init__()
self.channels = in_channels
self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
@@ -335,7 +329,7 @@ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
return new_projection
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, hidden_states):
residual = hidden_states
batch, channel_dim, seq = hidden_states.shape
@@ -373,7 +367,7 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
class ResConvBlock(nn.Module):
- def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
+ def __init__(self, in_channels, mid_channels, out_channels, is_last=False):
super().__init__()
self.is_last = is_last
self.has_conv_skip = in_channels != out_channels
@@ -390,7 +384,7 @@ def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_la
self.group_norm_2 = nn.GroupNorm(1, out_channels)
self.gelu_2 = nn.GELU()
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, hidden_states):
residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
hidden_states = self.conv_1(hidden_states)
@@ -407,7 +401,7 @@ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
class UNetMidBlock1D(nn.Module):
- def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None):
+ def __init__(self, mid_channels, in_channels, out_channels=None):
super().__init__()
out_channels = in_channels if out_channels is None else out_channels
@@ -435,7 +429,7 @@ def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[i
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None):
hidden_states = self.down(hidden_states)
for attn, resnet in zip(self.attentions, self.resnets):
hidden_states = resnet(hidden_states)
@@ -447,7 +441,7 @@ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTe
class AttnDownBlock1D(nn.Module):
- def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
+ def __init__(self, out_channels, in_channels, mid_channels=None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
@@ -466,7 +460,7 @@ def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[i
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None):
hidden_states = self.down(hidden_states)
for resnet, attn in zip(self.resnets, self.attentions):
@@ -477,7 +471,7 @@ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTe
class DownBlock1D(nn.Module):
- def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
+ def __init__(self, out_channels, in_channels, mid_channels=None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
@@ -490,7 +484,7 @@ def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[i
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None):
hidden_states = self.down(hidden_states)
for resnet in self.resnets:
@@ -500,7 +494,7 @@ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTe
class DownBlock1DNoSkip(nn.Module):
- def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None):
+ def __init__(self, out_channels, in_channels, mid_channels=None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
@@ -512,7 +506,7 @@ def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[i
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None):
hidden_states = torch.cat([hidden_states, temb], dim=1)
for resnet in self.resnets:
hidden_states = resnet(hidden_states)
@@ -521,7 +515,7 @@ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTe
class AttnUpBlock1D(nn.Module):
- def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
mid_channels = out_channels if mid_channels is None else mid_channels
@@ -540,12 +534,7 @@ def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[i
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -559,7 +548,7 @@ def forward(
class UpBlock1D(nn.Module):
- def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
@@ -572,12 +561,7 @@ def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[i
self.resnets = nn.ModuleList(resnets)
self.up = Upsample1d(kernel="cubic")
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -590,7 +574,7 @@ def forward(
class UpBlock1DNoSkip(nn.Module):
- def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None):
+ def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
mid_channels = in_channels if mid_channels is None else mid_channels
@@ -602,12 +586,7 @@ def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[i
self.resnets = nn.ModuleList(resnets)
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
res_hidden_states = res_hidden_states_tuple[-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
@@ -617,20 +596,7 @@ def forward(
return hidden_states
-DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip]
-MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D]
-OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock]
-UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip]
-
-
-def get_down_block(
- down_block_type: str,
- num_layers: int,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- add_downsample: bool,
-) -> DownBlockType:
+def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
if down_block_type == "DownResnetBlock1D":
return DownResnetBlock1D(
in_channels=in_channels,
@@ -648,9 +614,7 @@ def get_down_block(
raise ValueError(f"{down_block_type} does not exist.")
-def get_up_block(
- up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool
-) -> UpBlockType:
+def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
if up_block_type == "UpResnetBlock1D":
return UpResnetBlock1D(
in_channels=in_channels,
@@ -668,15 +632,7 @@ def get_up_block(
raise ValueError(f"{up_block_type} does not exist.")
-def get_mid_block(
- mid_block_type: str,
- num_layers: int,
- in_channels: int,
- mid_channels: int,
- out_channels: int,
- embed_dim: int,
- add_downsample: bool,
-) -> MidBlockType:
+def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
if mid_block_type == "MidResTemporalBlock1D":
return MidResTemporalBlock1D(
num_layers=num_layers,
@@ -692,9 +648,7 @@ def get_mid_block(
raise ValueError(f"{mid_block_type} does not exist.")
-def get_out_block(
- *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int
-) -> Optional[OutBlockType]:
+def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction":
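
The `get_down_block` / `get_up_block` / `get_mid_block` / `get_out_block` helpers at the end of `unet_1d_blocks.py` are plain string-dispatch factories: match a block-type name, build the corresponding module with the given sizes, and raise `ValueError` for anything unknown. A stripped-down sketch of the pattern with stand-in block classes rather than the real ones:

```py
from torch import nn


class DownResnetBlock1D(nn.Module):
    # Stand-in block: only the constructor arguments matter for this sketch.
    def __init__(self, in_channels, out_channels, num_layers, temb_channels, add_downsample):
        super().__init__()
        self.body = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)


class DownBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.body = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)


def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
    if down_block_type == "DownResnetBlock1D":
        return DownResnetBlock1D(
            in_channels=in_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
        )
    elif down_block_type == "DownBlock1D":
        return DownBlock1D(in_channels=in_channels, out_channels=out_channels)
    raise ValueError(f"{down_block_type} does not exist.")


block = get_down_block(
    "DownResnetBlock1D", num_layers=1, in_channels=32, out_channels=64,
    temb_channels=128, add_downsample=True,
)
```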
diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index 0531d8aae783..db6d3a5dce3f 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -117,7 +117,6 @@ def __init__(
add_attention: bool = True,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
- num_train_timesteps: Optional[int] = None,
):
super().__init__()
@@ -145,9 +144,6 @@ def __init__(
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
- elif time_embedding_type == "learned":
- self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
- timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
@@ -295,8 +291,6 @@ def forward(
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
- elif self.class_embedding is None and class_labels is not None:
- raise ValueError("class_embedding needs to be initialized in order to use class conditioning")
# 2. pre-process
skip_sample = sample
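
The `unet_2d.py` hunk above removes one of the two ways an integer timestep can be turned into a vector: the sinusoidal "positional" projection stays, while the "learned" `nn.Embedding` lookup (bounded by `num_train_timesteps`) goes. The distinction, reduced to plain PyTorch with made-up sizes rather than the library's `Timesteps` class:

```py
import math
import torch
from torch import nn

dim, num_train_timesteps = 64, 1000
timesteps = torch.tensor([0, 10, 999])

# "positional": fixed sinusoidal features, no parameters, defined for any timestep value.
half = dim // 2
freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
angles = timesteps[:, None].float() * freqs[None, :]
positional_emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (3, 64)

# "learned": a trainable lookup table, only valid for 0 <= t < num_train_timesteps.
learned_proj = nn.Embedding(num_train_timesteps, dim)
learned_emb = learned_proj(timesteps)                                       # (3, 64)

print(positional_emb.shape, learned_emb.shape)
```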
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index e404cef224ff..d57949976d30 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
@@ -21,9 +21,9 @@
from ..utils import is_torch_version, logging
from ..utils.torch_utils import apply_freeu
from .activations import get_activation
+from .attention import AdaGroupNorm
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from .dual_transformer_2d import DualTransformer2DModel
-from .normalization import AdaGroupNorm
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
from .transformer_2d import Transformer2DModel
@@ -32,31 +32,31 @@
def get_down_block(
- down_block_type: str,
- num_layers: int,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- add_downsample: bool,
- resnet_eps: float,
- resnet_act_fn: str,
- transformer_layers_per_block: int = 1,
- num_attention_heads: Optional[int] = None,
- resnet_groups: Optional[int] = None,
- cross_attention_dim: Optional[int] = None,
- downsample_padding: Optional[int] = None,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- attention_type: str = "default",
- resnet_skip_time_act: bool = False,
- resnet_out_scale_factor: float = 1.0,
- cross_attention_norm: Optional[str] = None,
- attention_head_dim: Optional[int] = None,
- downsample_type: Optional[str] = None,
- dropout: float = 0.0,
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ transformer_layers_per_block=1,
+ num_attention_heads=None,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ attention_type="default",
+ resnet_skip_time_act=False,
+ resnet_out_scale_factor=1.0,
+ cross_attention_norm=None,
+ attention_head_dim=None,
+ downsample_type=None,
+ dropout=0.0,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
@@ -241,33 +241,33 @@ def get_down_block(
def get_up_block(
- up_block_type: str,
- num_layers: int,
- in_channels: int,
- out_channels: int,
- prev_output_channel: int,
- temb_channels: int,
- add_upsample: bool,
- resnet_eps: float,
- resnet_act_fn: str,
- resolution_idx: Optional[int] = None,
- transformer_layers_per_block: int = 1,
- num_attention_heads: Optional[int] = None,
- resnet_groups: Optional[int] = None,
- cross_attention_dim: Optional[int] = None,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- attention_type: str = "default",
- resnet_skip_time_act: bool = False,
- resnet_out_scale_factor: float = 1.0,
- cross_attention_norm: Optional[str] = None,
- attention_head_dim: Optional[int] = None,
- upsample_type: Optional[str] = None,
- dropout: float = 0.0,
-) -> nn.Module:
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ resolution_idx=None,
+ transformer_layers_per_block=1,
+ num_attention_heads=None,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ attention_type="default",
+ resnet_skip_time_act=False,
+ resnet_out_scale_factor=1.0,
+ cross_attention_norm=None,
+ attention_head_dim=None,
+ upsample_type=None,
+ dropout=0.0,
+):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
logger.warn(
@@ -498,41 +498,11 @@ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
)
self.fuse = nn.ReLU()
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ def forward(self, x):
return self.fuse(self.conv(x) + self.skip(x))
class UNetMidBlock2D(nn.Module):
- """
- A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
-
- Args:
- in_channels (`int`): The number of input channels.
- temb_channels (`int`): The number of temporal embedding channels.
- dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
- num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
- resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
- resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
- The type of normalization to apply to the time embeddings. This can help to improve the performance of the
- model on tasks with long-range temporal dependencies.
- resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
- resnet_groups (`int`, *optional*, defaults to 32):
- The number of groups to use in the group normalization layers of the resnet blocks.
- attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
- resnet_pre_norm (`bool`, *optional*, defaults to `True`):
- Whether to use pre-normalization for the resnet blocks.
- add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
- attention_head_dim (`int`, *optional*, defaults to 1):
- Dimension of a single attention head. The number of attention heads is determined based on this value and
- the number of input channels.
- output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
-
- Returns:
- `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
- in_channels, height, width)`.
-
- """
-
def __init__(
self,
in_channels: int,
@@ -546,8 +516,8 @@ def __init__(
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = 1.0,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
):
super().__init__()
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -617,7 +587,7 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
@@ -634,19 +604,19 @@ def __init__(
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- output_scale_factor: float = 1.0,
- cross_attention_dim: int = 1280,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- upcast_attention: bool = False,
- attention_type: str = "default",
+ num_attention_heads=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ attention_type="default",
):
super().__init__()
@@ -654,10 +624,6 @@ def __init__(
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
- # support for variable transformer layers per block
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * num_layers
-
# there is always at least one resnet
resnets = [
ResnetBlock2D(
@@ -675,14 +641,14 @@ def __init__(
]
attentions = []
- for i in range(num_layers):
+ for _ in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
- num_layers=transformer_layers_per_block[i],
+ num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
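
Several hunks in `unet_2d_blocks.py` (the one above and the matching changes further down) drop the small normalization step that let `transformer_layers_per_block` be either a single int or a per-layer sequence. The pattern itself is generic; a sketch with hypothetical names:

```py
from typing import List, Sequence, Union


def per_layer(value: Union[int, Sequence[int]], num_layers: int) -> List[int]:
    # A single int means "the same depth for every layer"; a sequence is taken as-is.
    if isinstance(value, int):
        return [value] * num_layers
    if len(value) != num_layers:
        raise ValueError(f"expected {num_layers} entries, got {len(value)}")
    return list(value)


print(per_layer(2, num_layers=3))          # [2, 2, 2]
print(per_layer((1, 2, 4), num_layers=3))  # [1, 2, 4]
```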
@@ -785,12 +751,12 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = 1.0,
- cross_attention_dim: int = 1280,
- skip_time_act: bool = False,
- only_cross_attention: bool = False,
- cross_attention_norm: Optional[str] = None,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ skip_time_act=False,
+ only_cross_attention=False,
+ cross_attention_norm=None,
):
super().__init__()
@@ -866,7 +832,7 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ ):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
@@ -910,10 +876,10 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = 1.0,
- downsample_padding: int = 1,
- downsample_type: str = "conv",
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ downsample_type="conv",
):
super().__init__()
resnets = []
@@ -989,13 +955,7 @@ def __init__(
else:
self.downsamplers = None
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- upsample_size: Optional[int] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_kwargs=None):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
@@ -1028,22 +988,22 @@ def __init__(
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- downsample_padding: int = 1,
- add_downsample: bool = True,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- attention_type: str = "default",
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ attention_type="default",
):
super().__init__()
resnets = []
@@ -1051,8 +1011,6 @@ def __init__(
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
@@ -1076,7 +1034,7 @@ def __init__(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
- num_layers=transformer_layers_per_block[i],
+ num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1120,8 +1078,8 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- additional_residuals: Optional[torch.FloatTensor] = None,
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ additional_residuals=None,
+ ):
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
@@ -1194,9 +1152,9 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_downsample: bool = True,
- downsample_padding: int = 1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
):
super().__init__()
resnets = []
@@ -1233,9 +1191,7 @@ def __init__(
self.gradient_checkpointing = False
- def forward(
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ def forward(self, hidden_states, temb=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
@@ -1281,9 +1237,9 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_downsample: bool = True,
- downsample_padding: int = 1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
):
super().__init__()
resnets = []
@@ -1318,7 +1274,7 @@ def __init__(
else:
self.downsamplers = None
- def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+ def forward(self, hidden_states, scale: float = 1.0):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, scale=scale)
@@ -1341,10 +1297,10 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = 1.0,
- add_downsample: bool = True,
- downsample_padding: int = 1,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
):
super().__init__()
resnets = []
@@ -1401,7 +1357,7 @@ def __init__(
else:
self.downsamplers = None
- def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+ def forward(self, hidden_states, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=None, scale=scale)
cross_attention_kwargs = {"scale": scale}
@@ -1426,9 +1382,9 @@ def __init__(
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = np.sqrt(2.0),
- add_downsample: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
):
super().__init__()
self.attentions = nn.ModuleList([])
@@ -1495,13 +1451,7 @@ def __init__(
self.downsamplers = None
self.skip_conv = None
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- skip_sample: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+ def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
@@ -1534,9 +1484,9 @@ def __init__(
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
- output_scale_factor: float = np.sqrt(2.0),
- add_downsample: bool = True,
- downsample_padding: int = 1,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
+ downsample_padding=1,
):
super().__init__()
self.resnets = nn.ModuleList([])
@@ -1582,13 +1532,7 @@ def __init__(
self.downsamplers = None
self.skip_conv = None
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- skip_sample: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+ def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
@@ -1620,9 +1564,9 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_downsample: bool = True,
- skip_time_act: bool = False,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ skip_time_act=False,
):
super().__init__()
resnets = []
@@ -1671,9 +1615,7 @@ def __init__(
self.gradient_checkpointing = False
- def forward(
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ def forward(self, hidden_states, temb=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
@@ -1720,13 +1662,13 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- add_downsample: bool = True,
- skip_time_act: bool = False,
- only_cross_attention: bool = False,
- cross_attention_norm: Optional[str] = None,
+ attention_head_dim=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ skip_time_act=False,
+ only_cross_attention=False,
+ cross_attention_norm=None,
):
super().__init__()
@@ -1810,7 +1752,7 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ ):
output_states = ()
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
@@ -1878,7 +1820,7 @@ def __init__(
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: int = 32,
- add_downsample: bool = False,
+ add_downsample=False,
):
super().__init__()
resnets = []
@@ -1913,9 +1855,7 @@ def __init__(
self.gradient_checkpointing = False
- def forward(
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ def forward(self, hidden_states, temb=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
@@ -1957,7 +1897,7 @@ def __init__(
dropout: float = 0.0,
num_layers: int = 4,
resnet_group_size: int = 32,
- add_downsample: bool = True,
+ add_downsample=True,
attention_head_dim: int = 64,
add_self_attention: bool = False,
resnet_eps: float = 1e-5,
@@ -2020,7 +1960,7 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ ):
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
@@ -2089,9 +2029,9 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = 1.0,
- upsample_type: str = "conv",
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ upsample_type="conv",
):
super().__init__()
resnets = []
@@ -2166,14 +2106,7 @@ def __init__(
self.resolution_idx = resolution_idx
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- upsample_size: Optional[int] = None,
- scale: float = 1.0,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -2201,24 +2134,24 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- attention_type: str = "default",
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ attention_type="default",
):
super().__init__()
resnets = []
@@ -2227,9 +2160,6 @@ def __init__(
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * num_layers
-
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
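
The lines removed in the hunk above normalize an integer `transformer_layers_per_block` into a per-layer list, so that `num_layers=transformer_layers_per_block[i]` can give each `Transformer2DModel` in the up block its own depth; the replacement passes the integer through unchanged. A minimal sketch of that normalization pattern, with an illustrative helper name (`expand_per_layer` is not part of the library):

```py
from typing import Sequence, Union

def expand_per_layer(value: Union[int, Sequence[int]], num_layers: int) -> list:
    # An int applies the same transformer depth to every layer ...
    if isinstance(value, int):
        return [value] * num_layers
    # ... while a sequence must already provide one entry per layer.
    if len(value) != num_layers:
        raise ValueError(f"expected {num_layers} entries, got {len(value)}")
    return list(value)

print(expand_per_layer(1, 3))           # [1, 1, 1]
print(expand_per_layer((2, 2, 10), 3))  # [2, 2, 10]
```
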
@@ -2254,7 +2184,7 @@ def __init__(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
- num_layers=transformer_layers_per_block[i],
+ num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -2295,7 +2225,7 @@ def forward(
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ ):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
@@ -2374,7 +2304,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2382,8 +2312,8 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
):
super().__init__()
resnets = []
@@ -2417,14 +2347,7 @@ def __init__(
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- upsample_size: Optional[int] = None,
- scale: float = 1.0,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -2482,7 +2405,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2490,9 +2413,9 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- temb_channels: Optional[int] = None,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ temb_channels=None,
):
super().__init__()
resnets = []
@@ -2524,9 +2447,7 @@ def __init__(
self.resolution_idx = resolution_idx
- def forward(
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
@@ -2542,7 +2463,7 @@ def __init__(
self,
in_channels: int,
out_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2550,10 +2471,10 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- temb_channels: Optional[int] = None,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ temb_channels=None,
):
super().__init__()
resnets = []
@@ -2608,9 +2529,7 @@ def __init__(
self.resolution_idx = resolution_idx
- def forward(
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
@@ -2630,16 +2549,16 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = np.sqrt(2.0),
- add_upsample: bool = True,
+ attention_head_dim=1,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
):
super().__init__()
self.attentions = nn.ModuleList([])
@@ -2717,14 +2636,7 @@ def __init__(
self.resolution_idx = resolution_idx
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- skip_sample=None,
- scale: float = 1.0,
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -2760,16 +2672,16 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
- output_scale_factor: float = np.sqrt(2.0),
- add_upsample: bool = True,
- upsample_padding: int = 1,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
+ upsample_padding=1,
):
super().__init__()
self.resnets = nn.ModuleList([])
@@ -2825,14 +2737,7 @@ def __init__(
self.resolution_idx = resolution_idx
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- skip_sample=None,
- scale: float = 1.0,
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -2865,7 +2770,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2873,9 +2778,9 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- skip_time_act: bool = False,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ skip_time_act=False,
):
super().__init__()
resnets = []
@@ -2927,14 +2832,7 @@ def __init__(
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- upsample_size: Optional[int] = None,
- scale: float = 1.0,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -2974,7 +2872,7 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -2982,13 +2880,13 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- skip_time_act: bool = False,
- only_cross_attention: bool = False,
- cross_attention_norm: Optional[str] = None,
+ attention_head_dim=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ skip_time_act=False,
+ only_cross_attention=False,
+ cross_attention_norm=None,
):
super().__init__()
resnets = []
@@ -3076,7 +2974,7 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ ):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
@@ -3145,7 +3043,7 @@ def __init__(
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: Optional[int] = 32,
- add_upsample: bool = True,
+ add_upsample=True,
):
super().__init__()
resnets = []
@@ -3183,14 +3081,7 @@ def __init__(
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- upsample_size: Optional[int] = None,
- scale: float = 1.0,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
@@ -3234,7 +3125,7 @@ def __init__(
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: int = 32,
- attention_head_dim: int = 1, # attention dim_head
+ attention_head_dim=1, # attention dim_head
cross_attention_dim: int = 768,
add_upsample: bool = True,
upcast_attention: bool = False,
@@ -3318,7 +3209,7 @@ def forward(
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ ):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
@@ -3380,18 +3271,11 @@ class KAttentionBlock(nn.Module):
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
- attention_bias (`bool`, *optional*, defaults to `False`):
- Configure if the attention layers should contain a bias parameter.
- upcast_attention (`bool`, *optional*, defaults to `False`):
- Set to `True` to upcast the attention computation to `float32`.
- temb_channels (`int`, *optional*, defaults to 768):
- The number of channels in the token embedding.
- add_self_attention (`bool`, *optional*, defaults to `False`):
- Set to `True` to add self-attention to the block.
- cross_attention_norm (`str`, *optional*, defaults to `None`):
- The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
- group_size (`int`, *optional*, defaults to 32):
- The number of groups to separate the channels into for group normalization.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:obj:`int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:obj:`bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
"""
def __init__(
@@ -3437,10 +3321,10 @@ def __init__(
cross_attention_norm=cross_attention_norm,
)
- def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+ def _to_3d(self, hidden_states, height, weight):
return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
- def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+ def _to_4d(self, hidden_states, height, weight):
return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
def forward(
@@ -3453,7 +3337,7 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ ):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 1. Self-Attention
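
Across `unet_2d_blocks.py`, the signature changes above drop annotations but keep the same contract: down-type blocks take `hidden_states`, an optional `temb`, and a LoRA `scale`, and return both the transformed `hidden_states` and a tuple of `output_states` that later feeds the U-Net's long skip connections. A toy sketch of that contract (the class below is illustrative, not one of the library blocks):

```py
import torch
from torch import nn

class TinyDownBlock(nn.Module):
    def __init__(self, channels: int, num_layers: int = 2):
        super().__init__()
        # stand-ins for the ResnetBlock2D layers of a real down block
        self.resnets = nn.ModuleList(
            nn.Conv2d(channels, channels, 3, padding=1) for _ in range(num_layers)
        )

    def forward(self, hidden_states, temb=None, scale: float = 1.0):
        output_states = ()
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states)  # real blocks also consume temb and scale
            output_states += (hidden_states,)      # collected for the long skip connections
        return hidden_states, output_states

sample = torch.randn(1, 4, 8, 8)
out, skips = TinyDownBlock(4)(sample)
print(out.shape, len(skips))  # torch.Size([1, 4, 8, 8]) 2
```
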
diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py
index 8cf2f8eb24b4..eb3831aa707e 100644
--- a/src/diffusers/models/unet_2d_blocks_flax.py
+++ b/src/diffusers/models/unet_2d_blocks_flax.py
@@ -45,7 +45,6 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -126,7 +125,6 @@ class FlaxDownBlock2D(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -192,7 +190,6 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
in_channels: int
out_channels: int
prev_output_channel: int
@@ -278,7 +275,6 @@ class FlaxUpBlock2D(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
in_channels: int
out_channels: int
prev_output_channel: int
@@ -343,7 +339,6 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
in_channels: int
dropout: float = 0.0
num_layers: int = 1
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index dd91d8007229..4039fbfcc67a 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -20,7 +20,7 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -43,7 +43,6 @@
)
from .modeling_utils import ModelMixin
from .unet_2d_blocks import (
- UNetMidBlock2D,
UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn,
get_down_block,
@@ -87,7 +86,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
@@ -106,15 +105,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
- transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
- reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
- blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
@@ -148,9 +142,9 @@ class conditioning with `class_embed_type` equal to `None`.
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
- conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
- *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
- *optional*): The dimension of the `class_labels` input when
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
@@ -190,8 +184,7 @@ def __init__(
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
- reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -272,10 +265,6 @@ def __init__(
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
- if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
- for layer_number_per_block in transformer_layers_per_block:
- if isinstance(layer_number_per_block, list):
- raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# input
conv_in_padding = (conv_in_kernel - 1) // 2
@@ -511,19 +500,6 @@ def __init__(
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
- elif mid_block_type == "UNetMidBlock2D":
- self.mid_block = UNetMidBlock2D(
- in_channels=block_out_channels[-1],
- temb_channels=blocks_time_embed_dim,
- dropout=dropout,
- num_layers=0,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_groups=norm_num_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- add_attention=False,
- )
elif mid_block_type is None:
self.mid_block = None
else:
@@ -537,11 +513,7 @@ def __init__(
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
- reversed_transformer_layers_per_block = (
- list(reversed(transformer_layers_per_block))
- if reverse_transformer_layers_per_block is None
- else reverse_transformer_layers_per_block
- )
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
@@ -791,7 +763,7 @@ def disable_freeu(self):
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
- if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
setattr(upsample_block, k, None)
def forward(
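
The single-line change in `disable_freeu` above is subtle: `getattr(upsample_block, k, None)` never raises, while the bare `getattr(upsample_block, k)` on the right of the `or` raises `AttributeError` whenever `hasattr` returns `False` for a block that does not define the FreeU keys. A self-contained illustration (toy class, not library code):

```py
class Block:  # a block that never defines FreeU attributes such as "s1"
    pass

blk = Block()

# Defaulted getattr: the expression evaluates to False without raising.
print(hasattr(blk, "s1") or getattr(blk, "s1", None) is not None)

# Bare getattr: hasattr is False, so the right operand is evaluated and raises.
try:
    print(hasattr(blk, "s1") or getattr(blk, "s1") is not None)
except AttributeError as err:
    print("raises:", err)
```
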
@@ -806,7 +778,6 @@ def forward(
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
- down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
@@ -851,13 +822,6 @@ def forward(
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
are passed along to the UNet blocks.
- down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
- additional residuals to be added to UNet long skip connections from down blocks to up blocks for
- example from ControlNet side model(s)
- mid_block_additional_residual (`torch.Tensor`, *optional*):
- additional residual to be added to UNet mid block output, for example from ControlNet side model
- down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
- additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -874,11 +838,9 @@ def forward(
forward_upsample_size = False
upsample_size = None
- for dim in sample.shape[-2:]:
- if dim % default_overall_up_factor != 0:
- # Forward upsample size to force interpolation output size.
- forward_upsample_size = True
- break
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
@@ -1022,15 +984,6 @@ def forward(
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
- if "image_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
- )
- image_embeds = added_cond_kwargs.get("image_embeds")
- image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
- encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
-
# 2. pre-process
sample = self.conv_in(sample)
@@ -1047,30 +1000,15 @@ def forward(
scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
- # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
- is_adapter = down_intrablock_additional_residuals is not None
- # maintain backward compatibility for legacy usage, where
- # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
- # but can only use one or the other
- if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
- deprecate(
- "T2I should not use down_block_additional_residuals",
- "1.3.0",
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
- and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
- for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
- standard_warn=False,
- )
- down_intrablock_additional_residuals = down_block_additional_residuals
- is_adapter = True
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlock2D
additional_residuals = {}
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
sample, res_samples = downsample_block(
hidden_states=sample,
@@ -1083,8 +1021,9 @@ def forward(
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
- sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
down_block_res_samples += res_samples
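
The hunk above folds the T2I-Adapter path back onto `down_block_additional_residuals`: on the left-hand side, adapter residuals arrive through the separate `down_intrablock_additional_residuals` argument (with a deprecation shim for legacy callers) and are popped one per down block into the running `sample`, while ControlNet residuals stay in `down_block_additional_residuals` and are applied to the collected skip connections; on the right-hand side, `is_adapter` is instead inferred from `mid_block_additional_residual is None`. A hedged sketch of how the two kinds of residuals are consumed — the helper names and shapes below are illustrative, not the library API:

```py
import torch

def apply_controlnet_residuals(down_block_res_samples, controlnet_residuals):
    # ControlNet residuals are added to the collected skip connections after the down pass.
    return tuple(s + r for s, r in zip(down_block_res_samples, controlnet_residuals))

def apply_adapter_residual(sample, intrablock_residuals):
    # Adapter residuals are popped one at a time and added to the running sample inside the down pass.
    if intrablock_residuals:
        sample = sample + intrablock_residuals.pop(0)
    return sample

skips = tuple(torch.zeros(1, 4, 8, 8) for _ in range(3))
ctrl = tuple(torch.ones(1, 4, 8, 8) for _ in range(3))
print(apply_controlnet_residuals(skips, ctrl)[0].mean().item())  # 1.0

sample = torch.zeros(1, 4, 8, 8)
adapters = [torch.full((1, 4, 8, 8), 2.0)]
print(apply_adapter_residual(sample, adapters).mean().item())    # 2.0
```
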
@@ -1101,25 +1040,21 @@ def forward(
# 4. mid
if self.mid_block is not None:
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- encoder_attention_mask=encoder_attention_mask,
- )
- else:
- sample = self.mid_block(sample, emb)
-
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
# To support T2I-Adapter-XL
if (
is_adapter
- and len(down_intrablock_additional_residuals) > 0
- and sample.shape == down_intrablock_additional_residuals[0].shape
+ and len(down_block_additional_residuals) > 0
+ and sample.shape == down_block_additional_residuals[0].shape
):
- sample += down_intrablock_additional_residuals.pop(0)
+ sample += down_block_additional_residuals.pop(0)
if is_controlnet:
sample = sample + mid_block_additional_residual
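
The left-hand side of the hunk above only forwards text-conditioning arguments when the mid block actually performs cross attention; the plain `UNetMidBlock2D` option (removed from `__init__` earlier in this file's diff, where it is built with `num_layers=0` and `add_attention=False`) accepts only `(sample, emb)`. The right-hand side assumes a cross-attention mid block. A toy sketch of that duck-typed dispatch, using stand-in classes and a `getattr` default in place of the original `hasattr` check:

```py
class PlainMid:
    def __call__(self, sample, emb):
        return sample + emb

class CrossAttnMid:
    has_cross_attention = True
    def __call__(self, sample, emb, encoder_hidden_states=None):
        return sample + emb  # a real block would also attend over encoder_hidden_states

def run_mid(mid_block, sample, emb, encoder_hidden_states):
    # Only blocks that advertise cross attention receive the conditioning tensors.
    if getattr(mid_block, "has_cross_attention", False):
        return mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
    return mid_block(sample, emb)

print(run_mid(PlainMid(), 1.0, 2.0, None))        # 3.0
print(run_mid(CrossAttnMid(), 1.0, 2.0, "text"))  # 3.0
```
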
@@ -1164,7 +1099,7 @@ def forward(
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
- unscale_lora_layers(self, lora_scale)
+ unscale_lora_layers(self)
if not return_dict:
return (sample,)
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index 13f53e16e7ac..770cbf09ccac 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -100,18 +100,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
sample_size: int = 32
in_channels: int = 4
out_channels: int = 4
- down_block_types: Tuple[str, ...] = (
+ down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
- up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
only_cross_attention: Union[bool, Tuple[bool]] = False
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
- attention_head_dim: Union[int, Tuple[int, ...]] = 8
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
+ attention_head_dim: Union[int, Tuple[int]] = 8
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
@@ -120,7 +120,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
freq_shift: int = 0
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
- transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
addition_embed_type_num_heads: int = 64
@@ -158,7 +158,7 @@ def init_weights(self, rng: jax.Array) -> FrozenDict:
}
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
- def setup(self) -> None:
+ def setup(self):
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
@@ -320,15 +320,15 @@ def setup(self) -> None:
def __call__(
self,
- sample: jnp.ndarray,
- timesteps: Union[jnp.ndarray, float, int],
- encoder_hidden_states: jnp.ndarray,
+ sample,
+ timesteps,
+ encoder_hidden_states,
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
- down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
- mid_block_additional_residual: Optional[jnp.ndarray] = None,
+ down_block_additional_residuals=None,
+ mid_block_additional_residual=None,
return_dict: bool = True,
train: bool = False,
- ) -> Union[FlaxUNet2DConditionOutput, Tuple[jnp.ndarray]]:
+ ) -> Union[FlaxUNet2DConditionOutput, Tuple]:
r"""
Args:
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
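
The Flax config annotations in this file trade variable-length tuple types for fixed-length ones: in `typing`, `Tuple[str]` means a tuple of exactly one string, while `Tuple[str, ...]` means a string tuple of any length, so only the latter matches the multi-element defaults shown above. A quick illustration (this affects static checkers such as mypy, not runtime behaviour):

```py
from typing import Tuple

one_str: Tuple[str] = ("DownBlock2D",)                    # OK: exactly one element
many_str: Tuple[str, ...] = ("DownBlock2D", "UpBlock2D")  # OK: any length

# fixed: Tuple[str] = ("DownBlock2D", "UpBlock2D")  # a type checker reports a length mismatch
print(one_str, many_str)
```
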
diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py
index e9c505c347b0..180ae0dc1a81 100644
--- a/src/diffusers/models/unet_3d_blocks.py
+++ b/src/diffusers/models/unet_3d_blocks.py
@@ -12,58 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Tuple, Union
-
import torch
from torch import nn
-from ..utils import is_torch_version
from ..utils.torch_utils import apply_freeu
-from .attention import Attention
-from .dual_transformer_2d import DualTransformer2DModel
-from .resnet import (
- Downsample2D,
- ResnetBlock2D,
- SpatioTemporalResBlock,
- TemporalConvLayer,
- Upsample2D,
-)
+from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
from .transformer_2d import Transformer2DModel
-from .transformer_temporal import (
- TransformerSpatioTemporalModel,
- TransformerTemporalModel,
-)
+from .transformer_temporal import TransformerTemporalModel
def get_down_block(
- down_block_type: str,
- num_layers: int,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- add_downsample: bool,
- resnet_eps: float,
- resnet_act_fn: str,
- num_attention_heads: int,
- resnet_groups: Optional[int] = None,
- cross_attention_dim: Optional[int] = None,
- downsample_padding: Optional[int] = None,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = True,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- temporal_num_attention_heads: int = 8,
- temporal_max_seq_length: int = 32,
- transformer_layers_per_block: int = 1,
-) -> Union[
- "DownBlock3D",
- "CrossAttnDownBlock3D",
- "DownBlockMotion",
- "CrossAttnDownBlockMotion",
- "DownBlockSpatioTemporal",
- "CrossAttnDownBlockSpatioTemporal",
-]:
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ num_attention_heads,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=True,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+):
if down_block_type == "DownBlock3D":
return DownBlock3D(
num_layers=num_layers,
@@ -98,103 +74,29 @@ def get_down_block(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
- if down_block_type == "DownBlockMotion":
- return DownBlockMotion(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- resnet_time_scale_shift=resnet_time_scale_shift,
- temporal_num_attention_heads=temporal_num_attention_heads,
- temporal_max_seq_length=temporal_max_seq_length,
- )
- elif down_block_type == "CrossAttnDownBlockMotion":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
- return CrossAttnDownBlockMotion(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- downsample_padding=downsample_padding,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- temporal_num_attention_heads=temporal_num_attention_heads,
- temporal_max_seq_length=temporal_max_seq_length,
- )
- elif down_block_type == "DownBlockSpatioTemporal":
- # added for SDV
- return DownBlockSpatioTemporal(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- add_downsample=add_downsample,
- )
- elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
- # added for SDV
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
- return CrossAttnDownBlockSpatioTemporal(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- num_layers=num_layers,
- transformer_layers_per_block=transformer_layers_per_block,
- add_downsample=add_downsample,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads,
- )
-
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
- up_block_type: str,
- num_layers: int,
- in_channels: int,
- out_channels: int,
- prev_output_channel: int,
- temb_channels: int,
- add_upsample: bool,
- resnet_eps: float,
- resnet_act_fn: str,
- num_attention_heads: int,
- resolution_idx: Optional[int] = None,
- resnet_groups: Optional[int] = None,
- cross_attention_dim: Optional[int] = None,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = True,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
- temporal_num_attention_heads: int = 8,
- temporal_cross_attention_dim: Optional[int] = None,
- temporal_max_seq_length: int = 32,
- transformer_layers_per_block: int = 1,
- dropout: float = 0.0,
-) -> Union[
- "UpBlock3D",
- "CrossAttnUpBlock3D",
- "UpBlockMotion",
- "CrossAttnUpBlockMotion",
- "UpBlockSpatioTemporal",
- "CrossAttnUpBlockSpatioTemporal",
-]:
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ num_attention_heads,
+ resolution_idx=None,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=True,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+):
if up_block_type == "UpBlock3D":
return UpBlock3D(
num_layers=num_layers,
@@ -231,74 +133,6 @@ def get_up_block(
resnet_time_scale_shift=resnet_time_scale_shift,
resolution_idx=resolution_idx,
)
- if up_block_type == "UpBlockMotion":
- return UpBlockMotion(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resolution_idx=resolution_idx,
- temporal_num_attention_heads=temporal_num_attention_heads,
- temporal_max_seq_length=temporal_max_seq_length,
- )
- elif up_block_type == "CrossAttnUpBlockMotion":
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
- return CrossAttnUpBlockMotion(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- add_upsample=add_upsample,
- resnet_eps=resnet_eps,
- resnet_act_fn=resnet_act_fn,
- resnet_groups=resnet_groups,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads,
- dual_cross_attention=dual_cross_attention,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- resnet_time_scale_shift=resnet_time_scale_shift,
- resolution_idx=resolution_idx,
- temporal_num_attention_heads=temporal_num_attention_heads,
- temporal_max_seq_length=temporal_max_seq_length,
- )
- elif up_block_type == "UpBlockSpatioTemporal":
- # added for SDV
- return UpBlockSpatioTemporal(
- num_layers=num_layers,
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- resolution_idx=resolution_idx,
- add_upsample=add_upsample,
- )
- elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
- # added for SDV
- if cross_attention_dim is None:
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
- return CrossAttnUpBlockSpatioTemporal(
- in_channels=in_channels,
- out_channels=out_channels,
- prev_output_channel=prev_output_channel,
- temb_channels=temb_channels,
- num_layers=num_layers,
- transformer_layers_per_block=transformer_layers_per_block,
- add_upsample=add_upsample,
- cross_attention_dim=cross_attention_dim,
- num_attention_heads=num_attention_heads,
- resolution_idx=resolution_idx,
- )
-
raise ValueError(f"{up_block_type} does not exist.")
@@ -314,12 +148,12 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- output_scale_factor: float = 1.0,
- cross_attention_dim: int = 1280,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = True,
- upcast_attention: bool = False,
+ num_attention_heads=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=True,
+ upcast_attention=False,
):
super().__init__()
@@ -347,7 +181,6 @@ def __init__(
in_channels,
in_channels,
dropout=0.1,
- norm_num_groups=resnet_groups,
)
]
attentions = []
@@ -395,7 +228,6 @@ def __init__(
in_channels,
in_channels,
dropout=0.1,
- norm_num_groups=resnet_groups,
)
)
@@ -406,13 +238,13 @@ def __init__(
def forward(
self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- num_frames: int = 1,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> torch.FloatTensor:
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ ):
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
for attn, temp_attn, resnet, temp_conv in zip(
@@ -425,10 +257,7 @@ def forward(
return_dict=False,
)[0]
hidden_states = temp_attn(
- hidden_states,
- num_frames=num_frames,
- cross_attention_kwargs=cross_attention_kwargs,
- return_dict=False,
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
)[0]
hidden_states = resnet(hidden_states, temb)
hidden_states = temp_conv(hidden_states, num_frames=num_frames)
@@ -449,15 +278,15 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- downsample_padding: int = 1,
- add_downsample: bool = True,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
):
super().__init__()
resnets = []
@@ -489,7 +318,6 @@ def __init__(
out_channels,
out_channels,
dropout=0.1,
- norm_num_groups=resnet_groups,
)
)
attentions.append(
@@ -524,11 +352,7 @@ def __init__(
self.downsamplers = nn.ModuleList(
[
Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=downsample_padding,
- name="op",
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@@ -539,13 +363,13 @@ def __init__(
def forward(
self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- num_frames: int = 1,
- cross_attention_kwargs: Dict[str, Any] = None,
- ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ hidden_states,
+ temb=None,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ ):
# TODO(Patrick, William) - attention mask is not used
output_states = ()
@@ -561,10 +385,7 @@ def forward(
return_dict=False,
)[0]
hidden_states = temp_attn(
- hidden_states,
- num_frames=num_frames,
- cross_attention_kwargs=cross_attention_kwargs,
- return_dict=False,
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
)[0]
output_states += (hidden_states,)
@@ -591,9 +412,9 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_downsample: bool = True,
- downsample_padding: int = 1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
):
super().__init__()
resnets = []
@@ -620,7 +441,6 @@ def __init__(
out_channels,
out_channels,
dropout=0.1,
- norm_num_groups=resnet_groups,
)
)
@@ -631,11 +451,7 @@ def __init__(
self.downsamplers = nn.ModuleList(
[
Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=downsample_padding,
- name="op",
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
@@ -644,12 +460,7 @@ def __init__(
self.gradient_checkpointing = False
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- num_frames: int = 1,
- ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ def forward(self, hidden_states, temb=None, num_frames=1):
output_states = ()
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -681,15 +492,15 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- resolution_idx: Optional[int] = None,
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resolution_idx=None,
):
super().__init__()
resnets = []
@@ -723,7 +534,6 @@ def __init__(
out_channels,
out_channels,
dropout=0.1,
- norm_num_groups=resnet_groups,
)
)
attentions.append(
@@ -764,15 +574,15 @@ def __init__(
def forward(
self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- upsample_size: Optional[int] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- num_frames: int = 1,
- cross_attention_kwargs: Dict[str, Any] = None,
- ) -> torch.FloatTensor:
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ num_frames=1,
+ cross_attention_kwargs=None,
+ ):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -811,10 +621,7 @@ def forward(
return_dict=False,
)[0]
hidden_states = temp_attn(
- hidden_states,
- num_frames=num_frames,
- cross_attention_kwargs=cross_attention_kwargs,
- return_dict=False,
+ hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs, return_dict=False
)[0]
if self.upsamplers is not None:
@@ -838,9 +645,9 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- resolution_idx: Optional[int] = None,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ resolution_idx=None,
):
super().__init__()
resnets = []
@@ -869,7 +676,6 @@ def __init__(
out_channels,
out_channels,
dropout=0.1,
- norm_num_groups=resnet_groups,
)
)
@@ -884,14 +690,7 @@ def __init__(
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- upsample_size: Optional[int] = None,
- num_frames: int = 1,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -925,1471 +724,3 @@ def forward(
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
-
-
-class DownBlockMotion(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_downsample: bool = True,
- downsample_padding: int = 1,
- temporal_num_attention_heads: int = 1,
- temporal_cross_attention_dim: Optional[int] = None,
- temporal_max_seq_length: int = 32,
- ):
- super().__init__()
- resnets = []
- motion_modules = []
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- motion_modules.append(
- TransformerTemporalModel(
- num_attention_heads=temporal_num_attention_heads,
- in_channels=out_channels,
- norm_num_groups=resnet_groups,
- cross_attention_dim=temporal_cross_attention_dim,
- attention_bias=False,
- activation_fn="geglu",
- positional_embeddings="sinusoidal",
- num_positional_embeddings=temporal_max_seq_length,
- attention_head_dim=out_channels // temporal_num_attention_heads,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
- self.motion_modules = nn.ModuleList(motion_modules)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=downsample_padding,
- name="op",
- )
- ]
- )
- else:
- self.downsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- scale: float = 1.0,
- num_frames: int = 1,
- ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
- output_states = ()
-
- blocks = zip(self.resnets, self.motion_modules)
- for resnet, motion_module in blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- use_reentrant=False,
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb, scale
- )
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(motion_module),
- hidden_states.requires_grad_(),
- temb,
- num_frames,
- )
-
- else:
- hidden_states = resnet(hidden_states, temb, scale=scale)
- hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
-
- output_states = output_states + (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states, scale=scale)
-
- output_states = output_states + (hidden_states,)
-
- return hidden_states, output_states
-
-
-class CrossAttnDownBlockMotion(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- transformer_layers_per_block: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- downsample_padding: int = 1,
- add_downsample: bool = True,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- attention_type: str = "default",
- temporal_cross_attention_dim: Optional[int] = None,
- temporal_num_attention_heads: int = 8,
- temporal_max_seq_length: int = 32,
- ):
- super().__init__()
- resnets = []
- attentions = []
- motion_modules = []
-
- self.has_cross_attention = True
- self.num_attention_heads = num_attention_heads
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- if not dual_cross_attention:
- attentions.append(
- Transformer2DModel(
- num_attention_heads,
- out_channels // num_attention_heads,
- in_channels=out_channels,
- num_layers=transformer_layers_per_block,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- attention_type=attention_type,
- )
- )
- else:
- attentions.append(
- DualTransformer2DModel(
- num_attention_heads,
- out_channels // num_attention_heads,
- in_channels=out_channels,
- num_layers=1,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- )
- )
-
- motion_modules.append(
- TransformerTemporalModel(
- num_attention_heads=temporal_num_attention_heads,
- in_channels=out_channels,
- norm_num_groups=resnet_groups,
- cross_attention_dim=temporal_cross_attention_dim,
- attention_bias=False,
- activation_fn="geglu",
- positional_embeddings="sinusoidal",
- num_positional_embeddings=temporal_max_seq_length,
- attention_head_dim=out_channels // temporal_num_attention_heads,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
- self.motion_modules = nn.ModuleList(motion_modules)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=downsample_padding,
- name="op",
- )
- ]
- )
- else:
- self.downsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- num_frames: int = 1,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- additional_residuals: Optional[torch.FloatTensor] = None,
- ):
- output_states = ()
-
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
-
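-        # each layer runs a spatial resnet, a cross-attention transformer, and a temporal motion module in sequence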
- blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
- for i, (resnet, attn, motion_module) in enumerate(blocks):
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
- else:
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
- hidden_states = motion_module(
- hidden_states,
- num_frames=num_frames,
- )[0]
-
- # apply additional residuals to the output of the last pair of resnet and attention blocks
- if i == len(blocks) - 1 and additional_residuals is not None:
- hidden_states = hidden_states + additional_residuals
-
- output_states = output_states + (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states, scale=lora_scale)
-
- output_states = output_states + (hidden_states,)
-
- return hidden_states, output_states
-
-
-class CrossAttnUpBlockMotion(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- prev_output_channel: int,
- temb_channels: int,
- resolution_idx: Optional[int] = None,
- dropout: float = 0.0,
- num_layers: int = 1,
- transformer_layers_per_block: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- attention_type: str = "default",
- temporal_cross_attention_dim: Optional[int] = None,
- temporal_num_attention_heads: int = 8,
- temporal_max_seq_length: int = 32,
- ):
- super().__init__()
- resnets = []
- attentions = []
- motion_modules = []
-
- self.has_cross_attention = True
- self.num_attention_heads = num_attention_heads
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- if not dual_cross_attention:
- attentions.append(
- Transformer2DModel(
- num_attention_heads,
- out_channels // num_attention_heads,
- in_channels=out_channels,
- num_layers=transformer_layers_per_block,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- use_linear_projection=use_linear_projection,
- only_cross_attention=only_cross_attention,
- upcast_attention=upcast_attention,
- attention_type=attention_type,
- )
- )
- else:
- attentions.append(
- DualTransformer2DModel(
- num_attention_heads,
- out_channels // num_attention_heads,
- in_channels=out_channels,
- num_layers=1,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- )
- )
- motion_modules.append(
- TransformerTemporalModel(
- num_attention_heads=temporal_num_attention_heads,
- in_channels=out_channels,
- norm_num_groups=resnet_groups,
- cross_attention_dim=temporal_cross_attention_dim,
- attention_bias=False,
- activation_fn="geglu",
- positional_embeddings="sinusoidal",
- num_positional_embeddings=temporal_max_seq_length,
- attention_head_dim=out_channels // temporal_num_attention_heads,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
- self.motion_modules = nn.ModuleList(motion_modules)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
- else:
- self.upsamplers = None
-
- self.gradient_checkpointing = False
- self.resolution_idx = resolution_idx
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- upsample_size: Optional[int] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- num_frames: int = 1,
- ) -> torch.FloatTensor:
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
- is_freeu_enabled = (
- getattr(self, "s1", None)
- and getattr(self, "s2", None)
- and getattr(self, "b1", None)
- and getattr(self, "b2", None)
- )
-
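-        # each layer concatenates the matching down-block skip connection, then runs resnet, cross-attention, and the motion module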
- blocks = zip(self.resnets, self.attentions, self.motion_modules)
- for resnet, attn, motion_module in blocks:
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
-
- # FreeU: Only operate on the first two stages
- if is_freeu_enabled:
- hidden_states, res_hidden_states = apply_freeu(
- self.resolution_idx,
- hidden_states,
- res_hidden_states,
- s1=self.s1,
- s2=self.s2,
- b1=self.b1,
- b2=self.b2,
- )
-
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
- else:
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
- hidden_states = motion_module(
- hidden_states,
- num_frames=num_frames,
- )[0]
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
-
- return hidden_states
-
-
-class UpBlockMotion(nn.Module):
- def __init__(
- self,
- in_channels: int,
- prev_output_channel: int,
- out_channels: int,
- temb_channels: int,
- resolution_idx: Optional[int] = None,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- temporal_norm_num_groups: int = 32,
- temporal_cross_attention_dim: Optional[int] = None,
- temporal_num_attention_heads: int = 8,
- temporal_max_seq_length: int = 32,
- ):
- super().__init__()
- resnets = []
- motion_modules = []
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- ResnetBlock2D(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- motion_modules.append(
- TransformerTemporalModel(
- num_attention_heads=temporal_num_attention_heads,
- in_channels=out_channels,
- norm_num_groups=temporal_norm_num_groups,
- cross_attention_dim=temporal_cross_attention_dim,
- attention_bias=False,
- activation_fn="geglu",
- positional_embeddings="sinusoidal",
- num_positional_embeddings=temporal_max_seq_length,
- attention_head_dim=out_channels // temporal_num_attention_heads,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
- self.motion_modules = nn.ModuleList(motion_modules)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
- else:
- self.upsamplers = None
-
- self.gradient_checkpointing = False
- self.resolution_idx = resolution_idx
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- upsample_size=None,
- scale: float = 1.0,
- num_frames: int = 1,
- ) -> torch.FloatTensor:
- is_freeu_enabled = (
- getattr(self, "s1", None)
- and getattr(self, "s2", None)
- and getattr(self, "b1", None)
- and getattr(self, "b2", None)
- )
-
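-        # pair each spatial resnet with its temporal motion module; a skip connection is concatenated before each resnet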
- blocks = zip(self.resnets, self.motion_modules)
-
- for resnet, motion_module in blocks:
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
-
- # FreeU: Only operate on the first two stages
- if is_freeu_enabled:
- hidden_states, res_hidden_states = apply_freeu(
- self.resolution_idx,
- hidden_states,
- res_hidden_states,
- s1=self.s1,
- s2=self.s2,
- b1=self.b1,
- b2=self.b2,
- )
-
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- use_reentrant=False,
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet), hidden_states, temb
- )
-                # checkpoint the temporal motion module as well (mirrors DownBlockMotion)
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(motion_module),
-                    hidden_states.requires_grad_(),
-                    temb,
-                    num_frames,
-                )
-
- else:
- hidden_states = resnet(hidden_states, temb, scale=scale)
- hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
-
- return hidden_states
-
-
-class UNetMidBlockCrossAttnMotion(nn.Module):
- def __init__(
- self,
- in_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- transformer_layers_per_block: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default",
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- output_scale_factor: float = 1.0,
- cross_attention_dim: int = 1280,
-        dual_cross_attention: bool = False,
-        use_linear_projection: bool = False,
-        upcast_attention: bool = False,
- attention_type: str = "default",
- temporal_num_attention_heads: int = 1,
- temporal_cross_attention_dim: Optional[int] = None,
- temporal_max_seq_length: int = 32,
- ):
- super().__init__()
-
- self.has_cross_attention = True
- self.num_attention_heads = num_attention_heads
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
-
- # there is always at least one resnet
- resnets = [
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- ]
- attentions = []
- motion_modules = []
-
- for _ in range(num_layers):
- if not dual_cross_attention:
- attentions.append(
- Transformer2DModel(
- num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
- num_layers=transformer_layers_per_block,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- use_linear_projection=use_linear_projection,
- upcast_attention=upcast_attention,
- attention_type=attention_type,
- )
- )
- else:
- attentions.append(
- DualTransformer2DModel(
- num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
- num_layers=1,
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
- )
- )
- resnets.append(
- ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
- motion_modules.append(
- TransformerTemporalModel(
- num_attention_heads=temporal_num_attention_heads,
- attention_head_dim=in_channels // temporal_num_attention_heads,
- in_channels=in_channels,
- norm_num_groups=resnet_groups,
- cross_attention_dim=temporal_cross_attention_dim,
- attention_bias=False,
- positional_embeddings="sinusoidal",
- num_positional_embeddings=temporal_max_seq_length,
- activation_fn="geglu",
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
- self.motion_modules = nn.ModuleList(motion_modules)
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- num_frames: int = 1,
- ) -> torch.FloatTensor:
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
- hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
-
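-        # after the initial resnet, interleave cross-attention, the temporal motion module, and a resnet per layer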
- blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
- for attn, resnet, motion_module in blocks:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(motion_module),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- **ckpt_kwargs,
- )
- else:
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- cross_attention_kwargs=cross_attention_kwargs,
- attention_mask=attention_mask,
- encoder_attention_mask=encoder_attention_mask,
- return_dict=False,
- )[0]
- hidden_states = motion_module(
- hidden_states,
- num_frames=num_frames,
- )[0]
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
-
- return hidden_states
-
-
-class MidBlockTemporalDecoder(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- attention_head_dim: int = 512,
- num_layers: int = 1,
- upcast_attention: bool = False,
- ):
- super().__init__()
-
- resnets = []
- attentions = []
- for i in range(num_layers):
- input_channels = in_channels if i == 0 else out_channels
- resnets.append(
- SpatioTemporalResBlock(
- in_channels=input_channels,
- out_channels=out_channels,
- temb_channels=None,
- eps=1e-6,
- temporal_eps=1e-5,
- merge_factor=0.0,
- merge_strategy="learned",
- switch_spatial_to_temporal_mix=True,
- )
- )
-
- attentions.append(
- Attention(
- query_dim=in_channels,
- heads=in_channels // attention_head_dim,
- dim_head=attention_head_dim,
- eps=1e-6,
- upcast_attention=upcast_attention,
- norm_num_groups=32,
- bias=True,
- residual_connection=True,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- image_only_indicator: torch.FloatTensor,
- ):
- hidden_states = self.resnets[0](
- hidden_states,
- image_only_indicator=image_only_indicator,
- )
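-        # the remaining resnets are interleaved with attention blocks (attention first, then resnet)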
- for resnet, attn in zip(self.resnets[1:], self.attentions):
- hidden_states = attn(hidden_states)
- hidden_states = resnet(
- hidden_states,
- image_only_indicator=image_only_indicator,
- )
-
- return hidden_states
-
-
-class UpBlockTemporalDecoder(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- num_layers: int = 1,
- add_upsample: bool = True,
- ):
- super().__init__()
- resnets = []
- for i in range(num_layers):
- input_channels = in_channels if i == 0 else out_channels
-
- resnets.append(
- SpatioTemporalResBlock(
- in_channels=input_channels,
- out_channels=out_channels,
- temb_channels=None,
- eps=1e-6,
- temporal_eps=1e-5,
- merge_factor=0.0,
- merge_strategy="learned",
- switch_spatial_to_temporal_mix=True,
- )
- )
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
- else:
- self.upsamplers = None
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- image_only_indicator: torch.FloatTensor,
- ) -> torch.FloatTensor:
- for resnet in self.resnets:
- hidden_states = resnet(
- hidden_states,
- image_only_indicator=image_only_indicator,
- )
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
-
- return hidden_states
-
-
-class UNetMidBlockSpatioTemporal(nn.Module):
- def __init__(
- self,
- in_channels: int,
- temb_channels: int,
- num_layers: int = 1,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- ):
- super().__init__()
-
- self.has_cross_attention = True
- self.num_attention_heads = num_attention_heads
-
- # support for variable transformer layers per block
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * num_layers
-
- # there is always at least one resnet
- resnets = [
- SpatioTemporalResBlock(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=1e-5,
- )
- ]
- attentions = []
-
- for i in range(num_layers):
- attentions.append(
- TransformerSpatioTemporalModel(
- num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
- num_layers=transformer_layers_per_block[i],
- cross_attention_dim=cross_attention_dim,
- )
- )
-
- resnets.append(
- SpatioTemporalResBlock(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=1e-5,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- image_only_indicator: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- hidden_states = self.resnets[0](
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
-
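-        # after the initial resnet, alternate spatio-temporal transformer blocks with spatio-temporal resnets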
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if self.training and self.gradient_checkpointing: # TODO
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- image_only_indicator=image_only_indicator,
- return_dict=False,
- )[0]
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- **ckpt_kwargs,
- )
- else:
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- image_only_indicator=image_only_indicator,
- return_dict=False,
- )[0]
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
-
- return hidden_states
-
-
-class DownBlockSpatioTemporal(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- num_layers: int = 1,
- add_downsample: bool = True,
- ):
- super().__init__()
- resnets = []
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- SpatioTemporalResBlock(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=1e-5,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- name="op",
- )
- ]
- )
- else:
- self.downsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- image_only_indicator: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
- output_states = ()
- for resnet in self.resnets:
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- use_reentrant=False,
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- )
- else:
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
-
- output_states = output_states + (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- output_states = output_states + (hidden_states,)
-
- return hidden_states, output_states
-
-
-class CrossAttnDownBlockSpatioTemporal(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- temb_channels: int,
- num_layers: int = 1,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- add_downsample: bool = True,
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- self.has_cross_attention = True
- self.num_attention_heads = num_attention_heads
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * num_layers
-
- for i in range(num_layers):
- in_channels = in_channels if i == 0 else out_channels
- resnets.append(
- SpatioTemporalResBlock(
- in_channels=in_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=1e-6,
- )
- )
- attentions.append(
- TransformerSpatioTemporalModel(
- num_attention_heads,
- out_channels // num_attention_heads,
- in_channels=out_channels,
- num_layers=transformer_layers_per_block[i],
- cross_attention_dim=cross_attention_dim,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_downsample:
- self.downsamplers = nn.ModuleList(
- [
- Downsample2D(
- out_channels,
- use_conv=True,
- out_channels=out_channels,
- padding=1,
- name="op",
- )
- ]
- )
- else:
- self.downsamplers = None
-
- self.gradient_checkpointing = False
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- image_only_indicator: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
- output_states = ()
-
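-        # each layer runs a spatio-temporal resnet followed by a spatio-temporal transformer; outputs feed the skip connections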
- blocks = list(zip(self.resnets, self.attentions))
- for resnet, attn in blocks:
- if self.training and self.gradient_checkpointing: # TODO
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- **ckpt_kwargs,
- )
-
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- image_only_indicator=image_only_indicator,
- return_dict=False,
- )[0]
- else:
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- image_only_indicator=image_only_indicator,
- return_dict=False,
- )[0]
-
- output_states = output_states + (hidden_states,)
-
- if self.downsamplers is not None:
- for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states)
-
- output_states = output_states + (hidden_states,)
-
- return hidden_states, output_states
-
-
-class UpBlockSpatioTemporal(nn.Module):
- def __init__(
- self,
- in_channels: int,
- prev_output_channel: int,
- out_channels: int,
- temb_channels: int,
- resolution_idx: Optional[int] = None,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- add_upsample: bool = True,
- ):
- super().__init__()
- resnets = []
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- SpatioTemporalResBlock(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- )
- )
-
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
- else:
- self.upsamplers = None
-
- self.gradient_checkpointing = False
- self.resolution_idx = resolution_idx
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- image_only_indicator: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- for resnet in self.resnets:
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
-
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- use_reentrant=False,
- )
- else:
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- )
- else:
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
-
- return hidden_states
-
-
-class CrossAttnUpBlockSpatioTemporal(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- prev_output_channel: int,
- temb_channels: int,
- resolution_idx: Optional[int] = None,
- num_layers: int = 1,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
- resnet_eps: float = 1e-6,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- add_upsample: bool = True,
- ):
- super().__init__()
- resnets = []
- attentions = []
-
- self.has_cross_attention = True
- self.num_attention_heads = num_attention_heads
-
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * num_layers
-
- for i in range(num_layers):
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
-
- resnets.append(
- SpatioTemporalResBlock(
- in_channels=resnet_in_channels + res_skip_channels,
- out_channels=out_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- )
- )
- attentions.append(
- TransformerSpatioTemporalModel(
- num_attention_heads,
- out_channels // num_attention_heads,
- in_channels=out_channels,
- num_layers=transformer_layers_per_block[i],
- cross_attention_dim=cross_attention_dim,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- if add_upsample:
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
- else:
- self.upsamplers = None
-
- self.gradient_checkpointing = False
- self.resolution_idx = resolution_idx
-
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- image_only_indicator: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- for resnet, attn in zip(self.resnets, self.attentions):
- # pop res hidden states
- res_hidden_states = res_hidden_states_tuple[-1]
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
-
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
-
- if self.training and self.gradient_checkpointing: # TODO
-
- def create_custom_forward(module, return_dict=None):
- def custom_forward(*inputs):
- if return_dict is not None:
- return module(*inputs, return_dict=return_dict)
- else:
- return module(*inputs)
-
- return custom_forward
-
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
- hidden_states = torch.utils.checkpoint.checkpoint(
- create_custom_forward(resnet),
- hidden_states,
- temb,
- image_only_indicator,
- **ckpt_kwargs,
- )
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- image_only_indicator=image_only_indicator,
- return_dict=False,
- )[0]
- else:
- hidden_states = resnet(
- hidden_states,
- temb,
- image_only_indicator=image_only_indicator,
- )
- hidden_states = attn(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- image_only_indicator=image_only_indicator,
- return_dict=False,
- )[0]
-
- if self.upsamplers is not None:
- for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states)
-
- return hidden_states
diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py
index 3c76b5aa8452..2ab1d4060e17 100644
--- a/src/diffusers/models/unet_3d_condition.py
+++ b/src/diffusers/models/unet_3d_condition.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -23,7 +22,6 @@
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
-from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
@@ -100,19 +98,14 @@ def __init__(
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
- down_block_types: Tuple[str, ...] = (
+ down_block_types: Tuple[str] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
- up_block_types: Tuple[str, ...] = (
- "UpBlock3D",
- "CrossAttnUpBlock3D",
- "CrossAttnUpBlock3D",
- "CrossAttnUpBlock3D",
- ),
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -180,7 +173,6 @@ def __init__(
attention_head_dim=attention_head_dim,
in_channels=block_out_channels[0],
num_layers=1,
- norm_num_groups=norm_num_groups,
)
# class embedding
@@ -273,7 +265,7 @@ def __init__(
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
- self.conv_act = get_activation("silu")
+ self.conv_act = nn.SiLU()
else:
self.conv_norm_out = None
self.conv_act = None
@@ -309,7 +301,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
- def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
@@ -411,7 +403,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
- def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ def enable_forward_chunking(self, chunk_size=None, dim=0):
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
@@ -467,7 +459,7 @@ def set_default_attn_processor(self):
self.set_attn_processor(processor, _remove_lora=True)
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
@@ -502,7 +494,7 @@ def disable_freeu(self):
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
- if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
setattr(upsample_block, k, None)
def forward(
@@ -517,7 +509,7 @@ def forward(
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
- ) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]:
+ ) -> Union[UNet3DConditionOutput, Tuple]:
r"""
The [`UNet3DConditionModel`] forward method.
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index 0049456e2187..36983eefc01f 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional
import numpy as np
import torch
@@ -22,17 +22,12 @@
from ..utils.torch_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
-from .unet_2d_blocks import (
- AutoencoderTinyBlock,
- UNetMidBlock2D,
- get_down_block,
- get_up_block,
-)
+from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class DecoderOutput(BaseOutput):
- r"""
+ """
Output of decoding method.
Args:
@@ -44,39 +39,16 @@ class DecoderOutput(BaseOutput):
class Encoder(nn.Module):
- r"""
- The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
-
- Args:
- in_channels (`int`, *optional*, defaults to 3):
- The number of input channels.
- out_channels (`int`, *optional*, defaults to 3):
- The number of output channels.
- down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
- The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
- options.
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
- The number of output channels for each block.
- layers_per_block (`int`, *optional*, defaults to 2):
- The number of layers per block.
- norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups for normalization.
- act_fn (`str`, *optional*, defaults to `"silu"`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- double_z (`bool`, *optional*, defaults to `True`):
- Whether to double the number of output channels for the last block.
- """
-
def __init__(
self,
- in_channels: int = 3,
- out_channels: int = 3,
- down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
- block_out_channels: Tuple[int, ...] = (64,),
- layers_per_block: int = 2,
- norm_num_groups: int = 32,
- act_fn: str = "silu",
- double_z: bool = True,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ norm_num_groups=32,
+ act_fn="silu",
+ double_z=True,
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -135,9 +107,8 @@ def __init__(
self.gradient_checkpointing = False
- def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
- r"""The forward method of the `Encoder` class."""
-
+ def forward(self, x):
+ sample = x
sample = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
@@ -181,38 +152,16 @@ def custom_forward(*inputs):
class Decoder(nn.Module):
- r"""
- The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
-
- Args:
- in_channels (`int`, *optional*, defaults to 3):
- The number of input channels.
- out_channels (`int`, *optional*, defaults to 3):
- The number of output channels.
- up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
- The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
- The number of output channels for each block.
- layers_per_block (`int`, *optional*, defaults to 2):
- The number of layers per block.
- norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups for normalization.
- act_fn (`str`, *optional*, defaults to `"silu"`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- norm_type (`str`, *optional*, defaults to `"group"`):
- The normalization type to use. Can be either `"group"` or `"spatial"`.
- """
-
def __init__(
self,
- in_channels: int = 3,
- out_channels: int = 3,
- up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
- block_out_channels: Tuple[int, ...] = (64,),
- layers_per_block: int = 2,
- norm_num_groups: int = 32,
- act_fn: str = "silu",
- norm_type: str = "group", # group, spatial
+ in_channels=3,
+ out_channels=3,
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ norm_num_groups=32,
+ act_fn="silu",
+ norm_type="group", # group, spatial
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -278,13 +227,8 @@ def __init__(
self.gradient_checkpointing = False
- def forward(
- self,
- sample: torch.FloatTensor,
- latent_embeds: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
- r"""The forward method of the `Decoder` class."""
-
+ def forward(self, z, latent_embeds=None):
+ sample = z
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
@@ -299,20 +243,14 @@ def custom_forward(*inputs):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block),
- sample,
- latent_embeds,
- use_reentrant=False,
+ create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block),
- sample,
- latent_embeds,
- use_reentrant=False,
+ create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
)
else:
# middle
@@ -345,16 +283,6 @@ def custom_forward(*inputs):
class UpSample(nn.Module):
- r"""
- The `UpSample` layer of a variational autoencoder that upsamples its input.
-
- Args:
- in_channels (`int`, *optional*, defaults to 3):
- The number of input channels.
- out_channels (`int`, *optional*, defaults to 3):
- The number of output channels.
- """
-
def __init__(
self,
in_channels: int,
@@ -366,7 +294,6 @@ def __init__(
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
- r"""The forward method of the `UpSample` class."""
x = torch.relu(x)
x = self.deconv(x)
return x
@@ -415,7 +342,6 @@ def __init__(
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
- r"""The forward method of the `MaskConditionEncoder` class."""
out = {}
for l in range(len(self.layers)):
layer = self.layers[l]
@@ -426,38 +352,19 @@ def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
class MaskConditionDecoder(nn.Module):
- r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
- decoder with a conditioner on the mask and masked image.
-
- Args:
- in_channels (`int`, *optional*, defaults to 3):
- The number of input channels.
- out_channels (`int`, *optional*, defaults to 3):
- The number of output channels.
- up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
- The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
- The number of output channels for each block.
- layers_per_block (`int`, *optional*, defaults to 2):
- The number of layers per block.
- norm_num_groups (`int`, *optional*, defaults to 32):
- The number of groups for normalization.
- act_fn (`str`, *optional*, defaults to `"silu"`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- norm_type (`str`, *optional*, defaults to `"group"`):
- The normalization type to use. Can be either `"group"` or `"spatial"`.
- """
+ """The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
+ decoder with a conditioner on the mask and masked image."""
def __init__(
self,
- in_channels: int = 3,
- out_channels: int = 3,
- up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
- block_out_channels: Tuple[int, ...] = (64,),
- layers_per_block: int = 2,
- norm_num_groups: int = 32,
- act_fn: str = "silu",
- norm_type: str = "group", # group, spatial
+ in_channels=3,
+ out_channels=3,
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ norm_num_groups=32,
+ act_fn="silu",
+ norm_type="group", # group, spatial
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -530,14 +437,7 @@ def __init__(
self.gradient_checkpointing = False
- def forward(
- self,
- z: torch.FloatTensor,
- image: Optional[torch.FloatTensor] = None,
- mask: Optional[torch.FloatTensor] = None,
- latent_embeds: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
- r"""The forward method of the `MaskConditionDecoder` class."""
+ def forward(self, z, image=None, mask=None, latent_embeds=None):
sample = z
sample = self.conv_in(sample)
@@ -553,10 +453,7 @@ def custom_forward(*inputs):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.mid_block),
- sample,
- latent_embeds,
- use_reentrant=False,
+ create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
)
sample = sample.to(upscale_dtype)
@@ -564,10 +461,7 @@ def custom_forward(*inputs):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.condition_encoder),
- masked_image,
- mask,
- use_reentrant=False,
+ create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
)
# up
@@ -577,10 +471,7 @@ def custom_forward(*inputs):
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(
- create_custom_forward(up_block),
- sample,
- latent_embeds,
- use_reentrant=False,
+ create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
@@ -595,9 +486,7 @@ def custom_forward(*inputs):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(self.condition_encoder),
- masked_image,
- mask,
+ create_custom_forward(self.condition_encoder), masked_image, mask
)
# up
@@ -650,14 +539,7 @@ class VectorQuantizer(nn.Module):
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(
- self,
- n_e: int,
- vq_embed_dim: int,
- beta: float,
- remap=None,
- unknown_index: str = "random",
- sane_index_shape: bool = False,
- legacy: bool = True,
+ self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
):
super().__init__()
self.n_e = n_e
@@ -671,7 +553,6 @@ def __init__(
self.remap = remap
if self.remap is not None:
self.register_buffer("used", torch.tensor(np.load(self.remap)))
- self.used: torch.Tensor
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
@@ -686,7 +567,7 @@ def __init__(
self.sane_index_shape = sane_index_shape
- def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
+ def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
@@ -700,7 +581,7 @@ def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
new[unknown] = self.unknown_index
return new.reshape(ishape)
- def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
+ def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape(ishape[0], -1)
@@ -710,7 +591,7 @@ def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
- def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
+ def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
@@ -729,7 +610,7 @@ def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatT
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
- z_q: torch.FloatTensor = z + (z_q - z).detach()
+ z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
@@ -744,7 +625,7 @@ def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatT
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
- def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
+ def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
@@ -752,7 +633,7 @@ def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...])
indices = indices.reshape(-1) # flatten again
# get quantized latent vectors
- z_q: torch.FloatTensor = self.embedding(indices)
+ z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.view(shape)
@@ -763,7 +644,7 @@ def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...])
class DiagonalGaussianDistribution(object):
- def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+ def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
@@ -778,23 +659,17 @@ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
- self.mean.shape,
- generator=generator,
- device=self.parameters.device,
- dtype=self.parameters.dtype,
+ self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
)
x = self.mean + self.std * sample
return x
- def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+ def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
- return 0.5 * torch.sum(
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
- dim=[1, 2, 3],
- )
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
@@ -805,43 +680,23 @@ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
dim=[1, 2, 3],
)
- def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
+ def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
- return 0.5 * torch.sum(
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
- dim=dims,
- )
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
- def mode(self) -> torch.Tensor:
+ def mode(self):
return self.mean
class EncoderTiny(nn.Module):
- r"""
- The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
-
- Args:
- in_channels (`int`):
- The number of input channels.
- out_channels (`int`):
- The number of output channels.
- num_blocks (`Tuple[int, ...]`):
- Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
- use.
- block_out_channels (`Tuple[int, ...]`):
- The number of output channels for each block.
- act_fn (`str`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- """
-
def __init__(
self,
in_channels: int,
out_channels: int,
- num_blocks: Tuple[int, ...],
- block_out_channels: Tuple[int, ...],
+ num_blocks: int,
+ block_out_channels: int,
act_fn: str,
):
super().__init__()
@@ -853,16 +708,7 @@ def __init__(
if i == 0:
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
else:
- layers.append(
- nn.Conv2d(
- num_channels,
- num_channels,
- kernel_size=3,
- padding=1,
- stride=2,
- bias=False,
- )
- )
+ layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
@@ -872,8 +718,7 @@ def __init__(
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
- r"""The forward method of the `EncoderTiny` class."""
+ def forward(self, x):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
@@ -895,31 +740,12 @@ def custom_forward(*inputs):
class DecoderTiny(nn.Module):
- r"""
- The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
-
- Args:
- in_channels (`int`):
- The number of input channels.
- out_channels (`int`):
- The number of output channels.
- num_blocks (`Tuple[int, ...]`):
- Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
- use.
- block_out_channels (`Tuple[int, ...]`):
- The number of output channels for each block.
- upsampling_scaling_factor (`int`):
- The scaling factor to use for upsampling.
- act_fn (`str`):
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
- """
-
def __init__(
self,
in_channels: int,
out_channels: int,
- num_blocks: Tuple[int, ...],
- block_out_channels: Tuple[int, ...],
+ num_blocks: int,
+ block_out_channels: int,
upsampling_scaling_factor: int,
act_fn: str,
):
@@ -941,21 +767,12 @@ def __init__(
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
conv_out_channel = num_channels if not is_final_block else out_channels
- layers.append(
- nn.Conv2d(
- num_channels,
- conv_out_channel,
- kernel_size=3,
- padding=1,
- bias=is_final_block,
- )
- )
+ layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
- r"""The forward method of the `DecoderTiny` class."""
+ def forward(self, x):
# Clamp.
x = torch.tanh(x / 3) * 3
diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py
index a1f98e813b89..d2dde2ba197b 100644
--- a/src/diffusers/models/vae_flax.py
+++ b/src/diffusers/models/vae_flax.py
@@ -214,7 +214,6 @@ class FlaxAttentionBlock(nn.Module):
Parameters `dtype`
"""
-
channels: int
num_head_channels: int = None
num_groups: int = 32
@@ -292,7 +291,6 @@ class FlaxDownEncoderBlock2D(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -349,7 +347,6 @@ class FlaxUpDecoderBlock2D(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
in_channels: int
out_channels: int
dropout: float = 0.0
@@ -404,7 +401,6 @@ class FlaxUNetMidBlock2D(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
in_channels: int
dropout: float = 0.0
num_layers: int = 1
@@ -492,7 +488,6 @@ class FlaxEncoder(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
-
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
@@ -605,7 +600,6 @@ class FlaxDecoder(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
parameters `dtype`
"""
-
in_channels: int = 3
out_channels: int = 3
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
@@ -773,7 +767,6 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
The `dtype` of the parameters.
"""
-
in_channels: int = 3
out_channels: int = 3
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
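The vae_flax.py hunks only drop the blank line after each class docstring. For readers less familiar with Flax, the modules in this file follow the Linen dataclass pattern: hyperparameters are typed class attributes and submodules are built in `setup()`. Below is a minimal sketch of that pattern; `DoubleConv` and its fields are hypothetical, not part of diffusers.

```py
# Illustrative sketch of the flax.linen pattern used throughout vae_flax.py:
# hyperparameters are typed class attributes, submodules are created in setup().
import jax
import jax.numpy as jnp
import flax.linen as nn


class DoubleConv(nn.Module):  # hypothetical module, not from diffusers
    out_channels: int = 8
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv1 = nn.Conv(self.out_channels, kernel_size=(3, 3), padding="SAME", dtype=self.dtype)
        self.conv2 = nn.Conv(self.out_channels, kernel_size=(3, 3), padding="SAME", dtype=self.dtype)

    def __call__(self, x):
        return self.conv2(nn.swish(self.conv1(x)))


module = DoubleConv()
x = jnp.ones((1, 32, 32, 3))  # Flax convs are channel-last (NHWC)
params = module.init(jax.random.PRNGKey(0), x)
print(module.apply(params, x).shape)  # (1, 32, 32, 8)
```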
diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py
index f4a6c8fb227f..0c15300af213 100644
--- a/src/diffusers/models/vq_model.py
+++ b/src/diffusers/models/vq_model.py
@@ -53,12 +53,10 @@ class VQModel(ModelMixin, ConfigMixin):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
- layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
- norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
scaling_factor (`float`, *optional*, defaults to `0.18215`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
@@ -67,8 +65,6 @@ class VQModel(ModelMixin, ConfigMixin):
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
- norm_type (`str`, *optional*, defaults to `"group"`):
- Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
"""
@register_to_config
@@ -76,9 +72,9 @@ def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
- down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
- up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
- block_out_channels: Tuple[int, ...] = (64,),
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 3,
@@ -148,9 +144,7 @@ def decode(
return DecoderOutput(sample=dec)
- def forward(
- self, sample: torch.FloatTensor, return_dict: bool = True
- ) -> Union[DecoderOutput, Tuple[torch.FloatTensor, ...]]:
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
The [`VQModel`] forward method.
@@ -164,8 +158,8 @@ def forward(
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
-
- h = self.encode(sample).latents
+ x = sample
+ h = self.encode(x).latents
dec = self.decode(h).sample
if not return_dict:
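The `scaling_factor` entry kept in the VQModel docstring describes a simple convention: latents are multiplied by `scaling_factor` after encoding and divided by it before decoding. A tiny sketch of that round trip follows; the helper names are invented for illustration.

```py
# Sketch of the scaling convention the docstring describes: multiply by
# scaling_factor after encoding, divide by it before decoding, so the diffusion
# model sees roughly unit-variance latents. Helper names are illustrative only.
import torch

scaling_factor = 0.18215  # the default quoted in the docstring


def to_diffusion_space(latents: torch.Tensor) -> torch.Tensor:
    return scaling_factor * latents  # z = scaling_factor * z, right after encode


def from_diffusion_space(latents: torch.Tensor) -> torch.Tensor:
    return latents / scaling_factor  # z = 1 / scaling_factor * z, right before decode


z = torch.randn(1, 4, 64, 64)
assert torch.allclose(from_diffusion_space(to_diffusion_space(z)), z)
```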
diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py
index 678d2c12cfe1..46e6125a0f55 100644
--- a/src/diffusers/optimization.py
+++ b/src/diffusers/optimization.py
@@ -37,7 +37,7 @@ class SchedulerType(Enum):
PIECEWISE_CONSTANT = "piecewise_constant"
-def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR:
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
@@ -53,7 +53,7 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaL
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
-def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR:
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
"""
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
increases linearly between 0 and the initial lr set in the optimizer.
@@ -78,7 +78,7 @@ def lr_lambda(current_step: int):
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
-def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR:
+def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
@@ -120,9 +120,7 @@ def rule_func(steps: int) -> float:
return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
-def get_linear_schedule_with_warmup(
- optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
-) -> LambdaLR:
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
@@ -153,7 +151,7 @@ def lr_lambda(current_step: int):
def get_cosine_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
-) -> LambdaLR:
+):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
@@ -187,7 +185,7 @@ def lr_lambda(current_step):
def get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
-) -> LambdaLR:
+):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
@@ -221,13 +219,8 @@ def lr_lambda(current_step):
def get_polynomial_decay_schedule_with_warmup(
- optimizer: Optimizer,
- num_warmup_steps: int,
- num_training_steps: int,
- lr_end: float = 1e-7,
- power: float = 1.0,
- last_epoch: int = -1,
-) -> LambdaLR:
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
"""
Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
@@ -295,7 +288,7 @@ def get_scheduler(
num_cycles: int = 1,
power: float = 1.0,
last_epoch: int = -1,
-) -> LambdaLR:
+):
"""
Unified API to get any scheduler from its name.
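The optimization.py hunks drop the `-> LambdaLR` return annotations but leave the schedule factories themselves untouched. A short usage sketch of one of them, with the warmup and step counts chosen arbitrarily for illustration:

```py
# Usage sketch for the schedule helpers touched above; the signature is taken
# directly from the hunk (optimizer, num_warmup_steps, num_training_steps, ...).
import torch
from diffusers.optimization import get_cosine_schedule_with_warmup

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Linear warmup for 100 steps, then cosine decay toward 0 over the remaining steps.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

for step in range(1000):
    optimizer.step()     # forward/backward elided in this sketch
    lr_scheduler.step()  # advances the LambdaLR schedule by one step
```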
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 5bb6a301ca4a..19fe2f72d447 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -17,12 +17,7 @@
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
-_import_structure = {
- "controlnet": [],
- "latent_diffusion": [],
- "stable_diffusion": [],
- "stable_diffusion_xl": [],
-}
+_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []}
try:
if not is_torch_available():
@@ -44,11 +39,7 @@
_import_structure["dit"] = ["DiTPipeline"]
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
- _import_structure["pipeline_utils"] = [
- "AudioPipelineOutput",
- "DiffusionPipeline",
- "ImagePipelineOutput",
- ]
+ _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
_import_structure["pndm"] = ["PNDMPipeline"]
_import_structure["repaint"] = ["RePaintPipeline"]
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
@@ -70,11 +61,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
- _import_structure["alt_diffusion"] = [
- "AltDiffusionImg2ImgPipeline",
- "AltDiffusionPipeline",
- ]
- _import_structure["animatediff"] = ["AnimateDiffPipeline"]
+ _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -122,18 +109,9 @@
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
- _import_structure["kandinsky3"] = [
- "Kandinsky3Img2ImgPipeline",
- "Kandinsky3Pipeline",
- ]
- _import_structure["latent_consistency_models"] = [
- "LatentConsistencyModelImg2ImgPipeline",
- "LatentConsistencyModelPipeline",
- ]
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
- _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_diffusion"].extend(
@@ -165,7 +143,6 @@
]
)
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
- _import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"]
_import_structure["stable_diffusion_xl"].extend(
[
"StableDiffusionXLImg2ImgPipeline",
@@ -174,14 +151,10 @@
"StableDiffusionXLPipeline",
]
)
- _import_structure["t2i_adapter"] = [
- "StableDiffusionAdapterPipeline",
- "StableDiffusionXLAdapterPipeline",
- ]
+ _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
_import_structure["text_to_video_synthesis"] = [
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
- "TextToVideoZeroSDXLPipeline",
"VideoToVideoSDPipeline",
]
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
@@ -235,9 +208,7 @@
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from ..utils import (
- dummy_torch_and_transformers_and_k_diffusion_objects,
- )
+ from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
@@ -280,10 +251,7 @@
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
- _import_structure["spectrogram_diffusion"] = [
- "MidiProcessor",
- "SpectrogramDiffusionPipeline",
- ]
+ _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -293,11 +261,7 @@
from ..utils.dummy_pt_objects import * # noqa F403
else:
- from .auto_pipeline import (
- AutoPipelineForImage2Image,
- AutoPipelineForInpainting,
- AutoPipelineForText2Image,
- )
+ from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
from .consistency_models import ConsistencyModelPipeline
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
@@ -305,11 +269,7 @@
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
- from .pipeline_utils import (
- AudioPipelineOutput,
- DiffusionPipeline,
- ImagePipelineOutput,
- )
+ from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
@@ -330,13 +290,8 @@
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
- from .animatediff import AnimateDiffPipeline
from .audioldm import AudioLDMPipeline
- from .audioldm2 import (
- AudioLDM2Pipeline,
- AudioLDM2ProjectionModel,
- AudioLDM2UNet2DConditionModel,
- )
+ from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .blip_diffusion import BlipDiffusionPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
@@ -376,18 +331,9 @@
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
)
- from .kandinsky3 import (
- Kandinsky3Img2ImgPipeline,
- Kandinsky3Pipeline,
- )
- from .latent_consistency_models import (
- LatentConsistencyModelImg2ImgPipeline,
- LatentConsistencyModelPipeline,
- )
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
- from .pixart_alpha import PixArtAlphaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_diffusion import (
@@ -422,15 +368,10 @@
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
- from .stable_video_diffusion import StableVideoDiffusionPipeline
- from .t2i_adapter import (
- StableDiffusionAdapterPipeline,
- StableDiffusionXLAdapterPipeline,
- )
+ from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
from .text_to_video_synthesis import (
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
- TextToVideoZeroSDXLPipeline,
VideoToVideoSDPipeline,
)
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
@@ -516,10 +457,7 @@
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
- from .spectrogram_diffusion import (
- MidiProcessor,
- SpectrogramDiffusionPipeline,
- )
+ from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
else:
import sys
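The `pipelines/__init__.py` hunks reshuffle the `_import_structure` map that drives the package's lazy imports. The sketch below shows the general idea behind such a map using a PEP 562 module `__getattr__`; it is not the exact diffusers machinery, just an illustration of how a name-to-submodule map can defer imports until first access.

```py
# Illustrative sketch, meant to live in a package's __init__.py (not the exact
# diffusers machinery): an import map drives lazy, on-demand submodule imports.
import importlib

_import_structure = {
    "pndm": ["PNDMPipeline"],
    "repaint": ["RePaintPipeline"],
}

# Reverse map: public name -> submodule that defines it.
_name_to_module = {name: mod for mod, names in _import_structure.items() for name in names}


def __getattr__(name):
    # Only import the submodule when one of its names is first accessed.
    if name in _name_to_module:
        module = importlib.import_module(f".{_name_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```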
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index b5c7aee4b4de..18518cc3783f 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -17,11 +17,11 @@
import torch
from packaging import version
-from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer
+from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
-from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...image_processor import VaeImageProcessor
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -73,55 +73,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionPipeline(
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
-):
+class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Alt Diffusion.
@@ -133,7 +86,6 @@ class AltDiffusionPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -154,11 +106,9 @@ class AltDiffusionPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -169,7 +119,6 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
- image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -246,7 +195,6 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
- image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -293,7 +241,10 @@ def _encode_prompt(
lora_scale: Optional[float] = None,
**kwargs,
):
- deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecation_message = (
+ "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
+ " instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ )
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
prompt_embeds_tuple = self.encode_prompt(
@@ -490,23 +441,10 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -522,7 +460,10 @@ def run_safety_checker(self, image, device, dtype):
return image, has_nsfw_concept
def decode_latents(self, latents):
- deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecation_message = (
+ "The decode_latents method is deprecated and will be removed in 1.0.0. Please use"
+ " VaeImageProcessor.postprocess(...) instead"
+ )
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
latents = 1 / self.vae.config.scaling_factor * latents
@@ -558,22 +499,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -644,61 +580,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def guidance_rescale(self):
- return self._guidance_rescale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -707,7 +588,6 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -716,15 +596,13 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -739,10 +617,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -767,12 +641,17 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+                A function that is called every `callback_steps` steps during inference. It is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -783,15 +662,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -802,23 +672,6 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
-
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -826,21 +679,9 @@ def __call__(
# 1. Check inputs. Raise error if not correct
self.check_inputs(
- prompt,
- height,
- width,
- callback_steps,
- negative_prompt,
- prompt_embeds,
- negative_prompt_embeds,
- callback_on_step_end_tensor_inputs,
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
- self._guidance_scale = guidance_scale
- self._guidance_rescale = guidance_rescale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -850,37 +691,34 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
- lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
- )
+ lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
-
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
-
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -898,24 +736,12 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 6.1 Add image embeds for IP-Adapter
- added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
-
- # 6.2 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
@@ -923,34 +749,22 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -959,9 +773,7 @@ def __call__(
callback(step_idx, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
- 0
- ]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
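The pipeline_alt_diffusion.py hunks move the classifier-free guidance bookkeeping back inline into `__call__`. The arithmetic performed at each denoising step is small enough to show in isolation; the `rescale_noise_cfg` body below is a paraphrase of the helper defined at the top of this file (per §3.4 of arXiv:2305.08891), not a verbatim copy.

```py
# Sketch of the per-step guidance arithmetic the restored __call__ performs:
# classifier-free guidance plus the optional rescaling from arXiv:2305.08891.
import torch


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Rescale the guided prediction so its per-sample std matches the text branch,
    # then blend it back with the unrescaled prediction.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


guidance_scale, guidance_rescale = 7.5, 0.7
noise_pred = torch.randn(2, 4, 64, 64)  # [uncond, text] stacked along the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

# guidance_scale = 1 means no guidance; larger values push toward the text branch.
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
guided = rescale_noise_cfg(guided, noise_pred_text, guidance_rescale=guidance_rescale)
```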
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 4272fa124755..de8f1071d073 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -19,11 +19,11 @@
import PIL.Image
import torch
from packaging import version
-from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer
+from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -75,20 +75,6 @@
"""
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
@@ -113,54 +99,9 @@ def preprocess(image):
return image
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionImg2ImgPipeline(
- DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image-to-image generation using Alt Diffusion.
@@ -173,7 +114,6 @@ class AltDiffusionImg2ImgPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -194,11 +134,9 @@ class AltDiffusionImg2ImgPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -209,7 +147,6 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
- image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -286,7 +223,6 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
- image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -304,7 +240,10 @@ def _encode_prompt(
lora_scale: Optional[float] = None,
**kwargs,
):
- deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecation_message = (
+ "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
+ " instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ )
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
prompt_embeds_tuple = self.encode_prompt(
@@ -501,23 +440,10 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -533,7 +459,10 @@ def run_safety_checker(self, image, device, dtype):
return image, has_nsfw_concept
def decode_latents(self, latents):
- deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecation_message = (
+ "The decode_latents method is deprecated and will be removed in 1.0.0. Please use"
+ " VaeImageProcessor.postprocess(...) instead"
+ )
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
latents = 1 / self.vae.config.scaling_factor * latents
@@ -561,30 +490,19 @@ def prepare_extra_step_kwargs(self, generator, eta):
return extra_step_kwargs
def check_inputs(
- self,
- prompt,
- strength,
- callback_steps,
- negative_prompt=None,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -636,18 +554,17 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective"
+ f" batch size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
- for i in range(batch_size)
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
- init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
@@ -704,57 +621,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -763,7 +629,6 @@ def __call__(
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
- timesteps: List[int] = None,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -771,14 +636,12 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -801,10 +664,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -825,27 +684,23 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+                A function that is called every `callback_steps` steps during inference. It is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
@@ -855,37 +710,8 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
# 1. Check inputs. Raise error if not correct
- self.check_inputs(
- prompt,
- strength,
- callback_steps,
- negative_prompt,
- prompt_embeds,
- negative_prompt_embeds,
- callback_on_step_end_tensor_inputs,
- )
-
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
+ self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -894,75 +720,55 @@ def __call__(
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
-
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
-
# 4. Preprocess image
image = self.image_processor.preprocess(image)
# 5. set timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
latents = self.prepare_latents(
- image,
- latent_timestep,
- batch_size,
- num_images_per_prompt,
- prompt_embeds.dtype,
- device,
- generator,
+ image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 7.1 Add image embeds for IP-Adapter
- added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
-
- # 7.2 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
@@ -970,30 +776,18 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -1002,9 +796,7 @@ def __call__(
callback(step_idx, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
- 0
- ]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
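
The hunk above swaps the newer `callback_on_step_end` plumbing back to the legacy per-step `callback`/`callback_steps` arguments and sets timesteps directly on the scheduler. Below is a minimal sketch of the restored callback interface; the pipeline class, checkpoint name, prompt, and placeholder init image are illustrative assumptions, not taken from this patch.

```py
import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline


def log_latents(step: int, timestep: int, latents: torch.FloatTensor):
    # Invoked every `callback_steps` denoising steps with the current latents.
    print(f"step={step} t={int(timestep)} latent_std={latents.std().item():.4f}")


pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

init_image = Image.new("RGB", (512, 512), "white")  # placeholder init image

image = pipe(
    "a watercolor painting of a jungle research complex",
    image=init_image,
    strength=0.75,
    callback=log_latents,  # legacy per-step callback restored by this patch
    callback_steps=5,      # call it every 5 steps
).images[0]
```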
diff --git a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
index 9db3882a15f1..3345fb6e7586 100644
--- a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
+++ b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
@@ -72,7 +72,6 @@ class AudioLDMPipeline(DiffusionPipeline):
vocoder ([`~transformers.SpeechT5HifiGan`]):
Vocoder of class `SpeechT5HifiGan`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
def __init__(
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index a7c6cd82c8e7..13f12e75fb31 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -42,9 +42,6 @@
KandinskyV22InpaintPipeline,
KandinskyV22Pipeline,
)
-from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
-from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
-from .pixart_alpha import PixArtAlphaPipeline
from .stable_diffusion import (
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
@@ -65,12 +62,9 @@
("if", IFPipeline),
("kandinsky", KandinskyCombinedPipeline),
("kandinsky22", KandinskyV22CombinedPipeline),
- ("kandinsky3", Kandinsky3Pipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("wuerstchen", WuerstchenCombinedPipeline),
- ("lcm", LatentConsistencyModelPipeline),
- ("pixart", PixArtAlphaPipeline),
]
)
@@ -81,10 +75,8 @@
("if", IFImg2ImgPipeline),
("kandinsky", KandinskyImg2ImgCombinedPipeline),
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
- ("kandinsky3", Kandinsky3Img2ImgPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
- ("lcm", LatentConsistencyModelImg2ImgPipeline),
]
)
@@ -184,7 +176,6 @@ class AutoPipelineForText2Image(ConfigMixin):
diffusion pipeline's components.
"""
-
config_name = "model_index.json"
def __init__(self, *args, **kwargs):
@@ -378,7 +369,7 @@ def from_pipe(cls, pipeline, **kwargs):
if kwargs["controlnet"] is not None:
text_2_image_cls = _get_task_class(
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
- text_2_image_cls.__name__.replace("ControlNet", "").replace("Pipeline", "ControlNetPipeline"),
+ text_2_image_cls.__name__.replace("Pipeline", "ControlNetPipeline"),
)
else:
text_2_image_cls = _get_task_class(
@@ -455,7 +446,6 @@ class AutoPipelineForImage2Image(ConfigMixin):
diffusion pipeline's components.
"""
-
config_name = "model_index.json"
def __init__(self, *args, **kwargs):
@@ -652,9 +642,7 @@ def from_pipe(cls, pipeline, **kwargs):
if kwargs["controlnet"] is not None:
image_2_image_cls = _get_task_class(
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
- image_2_image_cls.__name__.replace("ControlNet", "").replace(
- "Img2ImgPipeline", "ControlNetImg2ImgPipeline"
- ),
+ image_2_image_cls.__name__.replace("Img2ImgPipeline", "ControlNetImg2ImgPipeline"),
)
else:
image_2_image_cls = _get_task_class(
@@ -731,7 +719,6 @@ class AutoPipelineForInpainting(ConfigMixin):
diffusion pipeline's components.
"""
-
config_name = "model_index.json"
def __init__(self, *args, **kwargs):
@@ -926,9 +913,7 @@ def from_pipe(cls, pipeline, **kwargs):
if kwargs["controlnet"] is not None:
inpainting_cls = _get_task_class(
AUTO_INPAINT_PIPELINES_MAPPING,
- inpainting_cls.__name__.replace("ControlNet", "").replace(
- "InpaintPipeline", "ControlNetInpaintPipeline"
- ),
+ inpainting_cls.__name__.replace("InpaintPipeline", "ControlNetInpaintPipeline"),
)
else:
inpainting_cls = _get_task_class(
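
The `from_pipe` changes above revert the ControlNet class-name translation to a single `str.replace`, so the target class is derived from the source pipeline's name and then resolved through the task mapping. A small sketch of that lookup logic follows; the stand-in dictionary replaces the real `AUTO_IMAGE2IMAGE_PIPELINES_MAPPING` and holds class names rather than classes, purely for illustration.

```py
from collections import OrderedDict

# Stand-in for AUTO_IMAGE2IMAGE_PIPELINES_MAPPING: model family -> img2img pipeline class name.
AUTO_IMAGE2IMAGE = OrderedDict(
    [
        ("stable-diffusion", "StableDiffusionImg2ImgPipeline"),
        ("stable-diffusion-controlnet", "StableDiffusionControlNetImg2ImgPipeline"),
    ]
)


def get_task_class(mapping, class_name):
    # Loosely mirrors _get_task_class: return the mapped entry matching the requested name.
    for candidate in mapping.values():
        if candidate == class_name:
            return candidate
    raise ValueError(f"no task class found for {class_name}")


# When a `controlnet` kwarg is present, the patched from_pipe builds the target name
# with a plain replace: Img2ImgPipeline -> ControlNetImg2ImgPipeline.
source = "StableDiffusionImg2ImgPipeline"
target = source.replace("Img2ImgPipeline", "ControlNetImg2ImgPipeline")
print(get_task_class(AUTO_IMAGE2IMAGE, target))  # StableDiffusionControlNetImg2ImgPipeline
```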
diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
index 19f62e789e2d..53d57188743d 100644
--- a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
@@ -19,21 +19,10 @@
from transformers import CLIPPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.configuration_clip import CLIPTextConfig
-from transformers.models.clip.modeling_clip import CLIPEncoder
-
-
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
- """
- Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
- """
- bsz, src_len = mask.size()
- tgt_len = tgt_len if tgt_len is not None else src_len
-
- expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-
- inverted_mask = 1.0 - expanded_mask
-
- return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+from transformers.models.clip.modeling_clip import (
+ CLIPEncoder,
+ _expand_mask,
+)
# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
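
For reference, the local `_expand_mask` removed above (in favour of the import from `transformers.models.clip.modeling_clip`) turns a `[bsz, seq_len]` padding mask into the `[bsz, 1, tgt_len, src_len]` additive mask the CLIP encoder expects, filling masked positions with the dtype's most negative value. A standalone sketch of that behaviour, reproduced from the deleted lines:

```py
from typing import Optional

import torch


def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    # [bsz, seq_len] -> [bsz, 1, tgt_len, src_len]; 0 where attended, large-negative where masked.
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


mask = torch.tensor([[1, 1, 0]])  # third token is padding
print(expand_mask(mask, torch.float32)[0, 0])  # last column is ~-3.4e38, the rest is 0
```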
diff --git a/src/diffusers/pipelines/consistency_models/__init__.py b/src/diffusers/pipelines/consistency_models/__init__.py
index 162d91c010ac..053a3666263f 100644
--- a/src/diffusers/pipelines/consistency_models/__init__.py
+++ b/src/diffusers/pipelines/consistency_models/__init__.py
@@ -6,9 +6,7 @@
)
-_import_structure = {
- "pipeline_consistency_models": ["ConsistencyModelPipeline"],
-}
+_import_structure = {"pipeline_consistency_models": ["ConsistencyModelPipeline"]}
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_consistency_models import ConsistencyModelPipeline
diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
index bf4107568b23..de1b1fd93c7f 100644
--- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
+++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
@@ -1,17 +1,3 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
from typing import Callable, List, Optional, Union
import torch
@@ -74,7 +60,6 @@ class ConsistencyModelPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Currently only
compatible with [`CMStochasticIterativeScheduler`].
"""
-
model_cpu_offload_seq = "unet"
def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py
index b1671050c93f..3b832c017064 100644
--- a/src/diffusers/pipelines/controlnet/__init__.py
+++ b/src/diffusers/pipelines/controlnet/__init__.py
@@ -1,80 +1,80 @@
-from typing import TYPE_CHECKING
-
-from ...utils import (
- DIFFUSERS_SLOW_IMPORT,
- OptionalDependencyNotAvailable,
- _LazyModule,
- get_objects_from_module,
- is_flax_available,
- is_torch_available,
- is_transformers_available,
-)
-
-
-_dummy_objects = {}
-_import_structure = {}
-
-try:
- if not (is_transformers_available() and is_torch_available()):
- raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
- from ...utils import dummy_torch_and_transformers_objects # noqa F403
-
- _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
-else:
- _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
- _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
- _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
- _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
- _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
- _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
- _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
- _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
-try:
- if not (is_transformers_available() and is_flax_available()):
- raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
- from ...utils import dummy_flax_and_transformers_objects # noqa F403
-
- _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
-else:
- _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
-
-
-if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
- try:
- if not (is_transformers_available() and is_torch_available()):
- raise OptionalDependencyNotAvailable()
-
- except OptionalDependencyNotAvailable:
- from ...utils.dummy_torch_and_transformers_objects import *
- else:
- from .multicontrolnet import MultiControlNetModel
- from .pipeline_controlnet import StableDiffusionControlNetPipeline
- from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
- from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
- from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
- from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
- from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
- from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
-
- try:
- if not (is_transformers_available() and is_flax_available()):
- raise OptionalDependencyNotAvailable()
- except OptionalDependencyNotAvailable:
- from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
- else:
- from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
-
-
-else:
- import sys
-
- sys.modules[__name__] = _LazyModule(
- __name__,
- globals()["__file__"],
- _import_structure,
- module_spec=__spec__,
- )
- for name, value in _dummy_objects.items():
- setattr(sys.modules[__name__], name, value)
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_flax_available,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
+ _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
+ _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
+ _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
+ _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
+ _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
+ _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
+ _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
+try:
+ if not (is_transformers_available() and is_flax_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_flax_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
+else:
+ _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .multicontrolnet import MultiControlNetModel
+ from .pipeline_controlnet import StableDiffusionControlNetPipeline
+ from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
+ from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
+ from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
+ from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
+ from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
+ from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
+
+ try:
+ if not (is_transformers_available() and is_flax_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
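
The `controlnet/__init__.py` above is rewritten wholesale with unchanged content, and it follows the diffusers lazy-import convention: declare `_import_structure`, then hand it to `_LazyModule` so heavy submodules are only imported when first accessed. A minimal, self-contained sketch of the same idea using module-level `__getattr__` (PEP 562) rather than diffusers internals; the module and class names below are made up for illustration.

```py
# lazy_pkg/__init__.py -- simplified stand-in for the _LazyModule pattern.
import importlib
from typing import Any

_import_structure = {
    "pipeline_a": ["PipelineA"],
    "pipeline_b": ["PipelineB"],
}

# Reverse map: public name -> submodule that defines it.
_name_to_module = {name: module for module, names in _import_structure.items() for name in names}


def __getattr__(name: str) -> Any:
    # Import the defining submodule only on first attribute access.
    if name in _name_to_module:
        module = importlib.import_module(f".{_name_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```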
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 1e19678b221d..f52b222ee129 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -20,10 +20,10 @@
import PIL.Image
import torch
import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -35,7 +35,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -91,53 +91,8 @@
"""
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
class StableDiffusionControlNetPipeline(
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -147,7 +102,6 @@ class StableDiffusionControlNetPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -172,11 +126,9 @@ class StableDiffusionControlNetPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -188,7 +140,6 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
- image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -221,7 +172,6 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
- image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -474,24 +424,10 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -548,21 +484,15 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
- callback_on_step_end_tensor_inputs=None,
):
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -796,58 +726,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -857,7 +735,6 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -866,18 +743,16 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -900,10 +775,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -928,7 +799,6 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -957,15 +827,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -976,23 +837,6 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
-
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1002,10 +846,9 @@ def __call__(
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
- control_guidance_start, control_guidance_end = (
- mult * [control_guidance_start],
- mult * [control_guidance_end],
- )
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -1018,13 +861,8 @@ def __call__(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1034,6 +872,10 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1047,30 +889,25 @@ def __call__(
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
-
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
image = self.prepare_image(
@@ -1081,7 +918,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image.shape[-2:]
@@ -1097,7 +934,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1109,8 +946,8 @@ def __call__(
assert False
# 5. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
- self._num_timesteps = len(timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -1125,21 +962,10 @@ def __call__(
latents,
)
- # 6.5 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 7.1 Add image embeds for IP-Adapter
- added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
-
- # 7.2 Create tensor stating which controlnets to keep
+ # 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
@@ -1150,21 +976,14 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- is_unet_compiled = is_compiled_module(self.unet)
- is_controlnet_compiled = is_compiled_module(self.controlnet)
- is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
- # Relevant thread:
- # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
- torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1191,7 +1010,7 @@ def __call__(
return_dict=False,
)
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
@@ -1203,32 +1022,20 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
- added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -1244,9 +1051,7 @@ def __call__(
torch.cuda.empty_cache()
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
- 0
- ]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
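
Among the helpers removed from `StableDiffusionControlNetPipeline` above is `get_guidance_scale_embedding`, which encodes the guidance weight `w` as a sinusoidal embedding for UNets exposing `time_cond_proj_dim` (the distilled/LCM-style case). A standalone sketch of the removed computation, reproduced from the deleted lines:

```py
import torch


def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512, dtype=torch.float32) -> torch.Tensor:
    # Sinusoidal embedding of the scaled guidance weight; output shape is (len(w), embedding_dim).
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad to the requested width
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb


# The removed call site passed `guidance_scale - 1`, e.g. 6.5 for guidance_scale=7.5.
print(guidance_scale_embedding(torch.full((2,), 6.5), embedding_dim=256).shape)  # torch.Size([2, 256])
```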
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index fa489941c987..edeadb118925 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -91,20 +91,6 @@
"""
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
def prepare_image(image):
if isinstance(image, torch.Tensor):
# Batch single image
@@ -164,11 +150,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -464,7 +448,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -524,21 +508,15 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
- callback_on_step_end_tensor_inputs=None,
):
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -755,12 +733,11 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
elif isinstance(generator, list):
init_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
- for i in range(batch_size)
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
- init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
@@ -819,29 +796,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -863,15 +817,14 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -927,6 +880,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -944,15 +903,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -963,23 +913,6 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
-
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -989,10 +922,9 @@ def __call__(
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
- control_guidance_start, control_guidance_end = (
- mult * [control_guidance_start],
- mult * [control_guidance_end],
- )
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -1005,13 +937,8 @@ def __call__(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1021,6 +948,10 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1034,27 +965,27 @@ def __call__(
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare image
- image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image = self.image_processor.preprocess(image).to(dtype=torch.float32)
# 5. Prepare controlnet_conditioning_image
if isinstance(controlnet, ControlNetModel):
@@ -1066,7 +997,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(controlnet, MultiControlNetModel):
@@ -1081,7 +1012,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1095,7 +1026,6 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
latents = self.prepare_latents(
@@ -1125,11 +1055,11 @@ def __call__(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1156,7 +1086,7 @@ def __call__(
return_dict=False,
)
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
@@ -1168,30 +1098,20 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -1207,9 +1127,7 @@ def __call__(
torch.cuda.empty_cache()
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
- 0
- ]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
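
The img2img diff above drops the `retrieve_latents` helper and calls `latent_dist.sample(generator)` directly, so only the sampling path survives. For reference, the deleted helper simply dispatched on what the VAE `encode` call returned; a standalone sketch reproduced from the removed lines:

```py
from typing import Optional

import torch


def retrieve_latents(encoder_output, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
    # Pick latents out of whatever the VAE encode call returned.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)  # stochastic sample
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()  # deterministic mode
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents  # plain latents already
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
```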
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 72c2250dd5ac..d25809a2e72a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -21,10 +21,10 @@
import PIL.Image
import torch
import torch.nn.functional as F
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -68,16 +68,18 @@
>>> mask_image = mask_image.resize((512, 512))
- >>> def make_canny_condition(image):
- ... image = np.array(image)
- ... image = cv2.Canny(image, 100, 200)
- ... image = image[:, :, None]
- ... image = np.concatenate([image, image, image], axis=2)
- ... image = Image.fromarray(image)
+ >>> def make_inpaint_condition(image, image_mask):
+ ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+ ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+
+ ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
+ ... image[image_mask > 0.5] = -1.0 # set as masked pixel
+ ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+ ... image = torch.from_numpy(image)
... return image
- >>> control_image = make_canny_condition(init_image)
+ >>> control_image = make_inpaint_condition(init_image, mask_image)
>>> controlnet = ControlNetModel.from_pretrained(
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
@@ -103,20 +105,6 @@
"""
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
"""
@@ -241,7 +229,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
class StableDiffusionControlNetInpaintPipeline(
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.
@@ -251,7 +239,6 @@ class StableDiffusionControlNetInpaintPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
@@ -287,11 +274,9 @@ class StableDiffusionControlNetInpaintPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -303,7 +288,6 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
- image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -336,7 +320,6 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
- image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -592,24 +575,10 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -678,24 +647,18 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
- callback_on_step_end_tensor_inputs=None,
):
if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -988,12 +951,12 @@ def prepare_mask_latents(
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
image_latents = self.vae.config.scaling_factor * image_latents
@@ -1027,29 +990,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1070,18 +1010,16 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -1149,12 +1087,17 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -1172,15 +1115,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeine class.
Examples:
@@ -1191,23 +1125,6 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
-
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1217,10 +1134,9 @@ def __call__(
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
- control_guidance_start, control_guidance_end = (
- mult * [control_guidance_start],
- mult * [control_guidance_end],
- )
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -1235,13 +1151,8 @@ def __call__(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1251,6 +1162,10 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
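The comment restored above ties `guidance_scale` to the guidance weight w; the toy snippet below illustrates the combination step it gates later in the denoising loop (tensors and shapes are made up for the example).

```py
# Toy illustration of classifier-free guidance: with w == 1 the update collapses to
# the conditional prediction, i.e. no extra guidance is applied.
import torch

noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)

for guidance_scale in (1.0, 7.5):
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    if guidance_scale == 1.0:
        assert torch.allclose(noise_pred, noise_pred_text)
```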
@@ -1264,30 +1179,25 @@ def __call__(
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
-
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
@@ -1298,7 +1208,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(controlnet, MultiControlNetModel):
@@ -1313,7 +1223,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1323,7 +1233,7 @@ def __call__(
else:
assert False
- # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
+ # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
@@ -1341,7 +1251,6 @@ def __call__(
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0
- self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
@@ -1378,16 +1287,13 @@ def __call__(
prompt_embeds.dtype,
device,
generator,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 7.1 Add image embeds for IP-Adapter
- added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
-
- # 7.2 Create tensor stating which controlnets to keep
+ # 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
@@ -1401,11 +1307,11 @@ def __call__(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1432,7 +1338,7 @@ def __call__(
return_dict=False,
)
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
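To make the comment above concrete, here is a small standalone illustration of the zero-padding applied to the ControlNet residuals when guess mode runs only on the conditional half (shapes are made up for the example).

```py
# Illustrative only: pad conditional-batch ControlNet residuals so they match the
# CFG-doubled UNet batch; the unconditional half receives zeros (i.e. "add 0").
import torch

down_block_res_samples = [torch.randn(1, 320, 64, 64), torch.randn(1, 640, 32, 32)]
mid_block_res_sample = torch.randn(1, 1280, 8, 8)

down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

assert all(d.shape[0] == 2 for d in down_block_res_samples)
assert mid_block_res_sample.shape[0] == 2
```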
@@ -1447,15 +1353,14 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
- added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
@@ -1464,7 +1369,7 @@ def __call__(
if num_channels_unet == 4:
init_latents_proper = image_latents
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
else:
init_mask = mask
@@ -1477,16 +1382,6 @@ def __call__(
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
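The blending line kept in the hunk above (used when the UNet has 4 latent channels) is a per-element mix in latent space; a toy, self-contained version:

```py
# Toy version of the inpainting blend: outside the mask keep the re-noised image
# latents, inside the mask keep the freshly denoised latents.
import torch

latents = torch.randn(1, 4, 64, 64)              # current denoised latents
init_latents_proper = torch.randn(1, 4, 64, 64)  # image latents noised to the current timestep
init_mask = (torch.rand(1, 1, 64, 64) > 0.5).float()

latents = (1 - init_mask) * init_latents_proper + init_mask * latents
```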
@@ -1502,9 +1397,7 @@ def __call__(
torch.cuda.empty_cache()
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
- 0
- ]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
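Taken together, the changes to this file restore the pre-`callback_on_step_end` interface. A hedged sketch of the restored `callback`/`callback_steps` arguments, reusing `pipe`, `init_image`, `mask_image`, and `control_image` from the sketch inserted after the docstring example above:

```py
# Sketch: step-wise progress reporting through the restored `callback` argument.
import torch


def log_progress(step: int, timestep: int, latents: torch.FloatTensor) -> None:
    # Invoked every `callback_steps` denoising steps with the current latents.
    print(f"step={step} timestep={int(timestep)} latent_std={latents.std().item():.3f}")


image = pipe(
    "a red brick house, high quality",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    callback=log_progress,
    callback_steps=5,
).images[0]
```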
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 0f51ad58a598..4418ede74bd3 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -34,7 +34,6 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
- deprecate,
is_invisible_watermark_available,
logging,
replace_example_docstring,
@@ -54,20 +53,6 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -90,24 +75,27 @@ def retrieve_latents(
>>> mask_image = mask_image.resize((1024, 1024))
- >>> def make_canny_condition(image):
- ... image = np.array(image)
- ... image = cv2.Canny(image, 100, 200)
- ... image = image[:, :, None]
- ... image = np.concatenate([image, image, image], axis=2)
- ... image = Image.fromarray(image)
+ >>> def make_inpaint_condition(image, image_mask):
+ ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+ ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+
+ ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
+ ... image[image_mask > 0.5] = -1.0  # set as masked pixel
+ ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+ ... image = torch.from_numpy(image)
... return image
- >>> control_image = make_canny_condition(init_image)
+ >>> control_image = make_inpaint_condition(init_image, mask_image)
>>> controlnet = ControlNetModel.from_pretrained(
- ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
+ ... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float32
... )
>>> pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
- ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float32
... )
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
>>> pipe.enable_model_cpu_offload()
>>> # generate image
@@ -179,10 +167,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
-
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
- _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["tokenizer", "text_encoder"]
def __init__(
self,
@@ -331,17 +317,12 @@ def encode_prompt(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
- if self.text_encoder is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
- else:
- scale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
- else:
- scale_lora_layers(self.text_encoder_2, lora_scale)
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -457,11 +438,7 @@ def encode_prompt(
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
- if self.text_encoder_2 is not None:
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -470,12 +447,7 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
-
- if self.text_encoder_2 is not None:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -487,15 +459,10 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)
- if self.text_encoder is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder_2, lora_scale)
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
+ unscale_lora_layers(self.text_encoder_2)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@@ -571,7 +538,6 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
- callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -582,20 +548,14 @@ def check_inputs(
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
f" {type(num_inference_steps)}."
)
-
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -798,14 +758,13 @@ def prepare_latents(
"However, either the image or the noise timestep has not been provided."
)
- if return_image_latents or (latents is None and not is_strength_max):
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ elif return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
- if image.shape[1] == 4:
- image_latents = image
- else:
- image_latents = self._encode_vae_image(image=image, generator=generator)
- image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
if latents is None and add_noise:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -838,12 +797,12 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
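The hunk above goes back to sampling latents directly from the VAE posterior. A minimal standalone sketch of the same operation; the VAE checkpoint name and the random input tensor are assumptions for illustration.

```py
# Hedged sketch: encode an image tensor to scaled latents the way `_encode_vae_image` does.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
image = torch.rand(1, 3, 1024, 1024, dtype=torch.float16, device="cuda") * 2 - 1  # fake image in [-1, 1]
generator = torch.Generator("cuda").manual_seed(0)

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample(generator=generator)
latents = vae.config.scaling_factor * latents  # same scaling applied after encoding
```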
@@ -920,32 +879,13 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)
-
- num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
- if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
- # if the scheduler is a 2nd order scheduler we might have to do +1
- # because `num_inference_steps` might be even given that every timestep
- # (except the highest one) is duplicated. If `num_inference_steps` is even it would
- # mean that we cut the timesteps in the middle of the denoising step
- # (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
- # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
- num_inference_steps = num_inference_steps + 1
-
- # because t_n+1 >= t_n, we slice the timesteps starting from the end
- timesteps = timesteps[-num_inference_steps:]
- return timesteps, num_inference_steps
+ timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
+ return torch.tensor(timesteps), len(timesteps)
return timesteps, num_inference_steps - t_start
def _get_add_time_ids(
- self,
- original_size,
- crops_coords_top_left,
- target_size,
- aesthetic_score,
- negative_aesthetic_score,
- dtype,
- text_encoder_projection_dim=None,
+ self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
):
if self.config.requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
@@ -955,7 +895,7 @@ def _get_add_time_ids(
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
passed_add_embed_dim = (
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
@@ -1031,29 +971,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1085,6 +1002,8 @@ def __call__(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
@@ -1097,9 +1016,6 @@ def __call__(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1194,6 +1110,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -1223,15 +1145,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeine class.
Examples:
@@ -1240,23 +1153,6 @@ def __call__(
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
-
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1266,10 +1162,9 @@ def __call__(
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
- control_guidance_start, control_guidance_end = (
- mult * [control_guidance_start],
- mult * [control_guidance_end],
- )
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
# # 0.0 Default height and width to unet
# height = height or self.unet.config.sample_size * self.vae_scale_factor
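For readers skimming the diff, the one-liner restored above simply broadcasts scalar start/end values to one entry per ControlNet; a tiny standalone illustration:

```py
# Standalone illustration of the broadcast: one (start, end) window per ControlNet.
control_guidance_start, control_guidance_end = 0.0, 0.8
mult = 2  # e.g. a MultiControlNetModel wrapping two ControlNets

control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
    control_guidance_end
]
assert control_guidance_start == [0.0, 0.0]
assert control_guidance_end == [0.8, 0.8]
```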
@@ -1282,10 +1177,9 @@ def __call__(
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
- control_guidance_start, control_guidance_end = (
- mult * [control_guidance_start],
- mult * [control_guidance_end],
- )
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
# 1. Check inputs
self.check_inputs(
@@ -1304,13 +1198,8 @@ def __call__(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1320,13 +1209,17 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
(
@@ -1339,7 +1232,7 @@ def __call__(
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -1347,7 +1240,7 @@ def __call__(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# 4. set timesteps
@@ -1368,7 +1261,6 @@ def denoising_value_valid(dnv):
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0
- self._num_timesteps = len(timesteps)
# 5. Preprocess mask and image - resizes image and mask w.r.t height and width
# 5.1 Prepare init image
@@ -1385,7 +1277,7 @@ def denoising_value_valid(dnv):
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(controlnet, MultiControlNetModel):
@@ -1400,7 +1292,7 @@ def denoising_value_valid(dnv):
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1454,7 +1346,7 @@ def denoising_value_valid(dnv):
prompt_embeds.dtype,
device,
generator,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
)
# 8. Check that sizes of mask, masked image and latents match
@@ -1499,11 +1391,6 @@ def denoising_value_valid(dnv):
# 10. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
- if self.text_encoder_2 is None:
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
- else:
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
@@ -1511,11 +1398,10 @@ def denoising_value_valid(dnv):
aesthetic_score,
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
@@ -1552,7 +1438,7 @@ def denoising_value_valid(dnv):
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1560,7 +1446,7 @@ def denoising_value_valid(dnv):
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# controlnet(s) inference
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1597,7 +1483,7 @@ def denoising_value_valid(dnv):
return_dict=False,
)
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
@@ -1612,7 +1498,7 @@ def denoising_value_valid(dnv):
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs,
@@ -1620,11 +1506,11 @@ def denoising_value_valid(dnv):
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
- if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
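The `guidance_rescale` branch above calls `rescale_noise_cfg`; the sketch below outlines the rescaling described in Sec. 3.4 of the paper referenced in the comment (written as a standalone function for clarity, not copied from this diff).

```py
# Sketch of guidance rescaling (Sec. 3.4, https://arxiv.org/pdf/2305.08891.pdf): shrink the
# CFG output's standard deviation back toward that of the text-conditioned prediction.
import torch


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # interpolate between the rescaled and the original CFG prediction
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```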
@@ -1633,7 +1519,7 @@ def denoising_value_valid(dnv):
if num_channels_unet == 4:
init_latents_proper = image_latents
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
else:
init_mask = mask
@@ -1646,16 +1532,6 @@ def denoising_value_valid(dnv):
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -1686,8 +1562,9 @@ def denoising_value_valid(dnv):
image = self.image_processor.postprocess(image, output_type=output_type)
- # Offload all models
- self.maybe_free_model_hooks()
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 4696781dce0c..f634f3f389a9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -20,23 +20,12 @@
import PIL.Image
import torch
import torch.nn.functional as F
-from transformers import (
- CLIPImageProcessor,
- CLIPTextModel,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionModelWithProjection,
-)
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers.utils.import_utils import is_invisible_watermark_available
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import (
- FromSingleFileMixin,
- IPAdapterMixin,
- StableDiffusionXLLoraLoaderMixin,
- TextualInversionLoaderMixin,
-)
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
@@ -46,15 +35,8 @@
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import (
- USE_PEFT_BACKEND,
- deprecate,
- logging,
- replace_example_docstring,
- scale_lora_layers,
- unscale_lora_layers,
-)
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -115,11 +97,7 @@
class StableDiffusionXLControlNetPipeline(
- DiffusionPipeline,
- TextualInversionLoaderMixin,
- StableDiffusionXLLoraLoaderMixin,
- IPAdapterMixin,
- FromSingleFileMixin,
+ DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -161,18 +139,9 @@ class StableDiffusionXLControlNetPipeline(
watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
watermarker is used.
"""
-
- # leave controlnet out on purpose because it iterates with unet
- model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
- _optional_components = [
- "tokenizer",
- "tokenizer_2",
- "text_encoder",
- "text_encoder_2",
- "feature_extractor",
- "image_encoder",
- ]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ model_cpu_offload_seq = (
+ "text_encoder->text_encoder_2->unet->vae" # leave controlnet out on purpose because it iterates with unet
+ )
def __init__(
self,
@@ -186,8 +155,6 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
- feature_extractor: CLIPImageProcessor = None,
- image_encoder: CLIPVisionModelWithProjection = None,
):
super().__init__()
@@ -203,8 +170,6 @@ def __init__(
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
- feature_extractor=feature_extractor,
- image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -320,17 +285,12 @@ def encode_prompt(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
- if self.text_encoder is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
- else:
- scale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
- else:
- scale_lora_layers(self.text_encoder_2, lora_scale)
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -446,11 +406,7 @@ def encode_prompt(
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
- if self.text_encoder_2 is not None:
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -459,12 +415,7 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
-
- if self.text_encoder_2 is not None:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -476,32 +427,13 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)
- if self.text_encoder is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder_2, lora_scale)
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
+ unscale_lora_layers(self.text_encoder_2)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -535,21 +467,15 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
- callback_on_step_end_tensor_inputs=None,
):
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -780,13 +706,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
return latents
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
- def _get_add_time_ids(
- self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
- ):
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
passed_add_embed_dim = (
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
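The dimension check in the restored `_get_add_time_ids` reduces to simple arithmetic; a worked example with typical SDXL values (the concrete numbers are illustrative, not read from this diff):

```py
# Worked example of the embedding-dimension check with typical SDXL config values.
addition_time_embed_dim = 256   # unet.config.addition_time_embed_dim (assumed typical value)
projection_dim = 1280           # text_encoder_2.config.projection_dim (assumed typical value)

original_size, crops_coords_top_left, target_size = (1024, 1024), (0, 0), (1024, 1024)
add_time_ids = list(original_size + crops_coords_top_left + target_size)  # 6 integers

passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + projection_dim
assert passed_add_embed_dim == 2816  # must equal unet.add_embedding.linear_1.in_features
```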
@@ -846,58 +770,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -919,9 +791,10 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
@@ -934,9 +807,6 @@ def __call__(
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -1000,12 +870,17 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -1052,15 +927,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeine class.
Examples:
@@ -1069,23 +935,6 @@ def __call__(
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned containing the output images.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
-
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1095,10 +944,9 @@ def __call__(
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
- control_guidance_start, control_guidance_end = (
- mult * [control_guidance_start],
- mult * [control_guidance_end],
- )
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -1115,13 +963,8 @@ def __call__(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1131,6 +974,10 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1142,9 +989,9 @@ def __call__(
)
guess_mode = guess_mode or global_pool_conditions
- # 3.1 Encode input prompt
+ # 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
(
prompt_embeds,
@@ -1156,7 +1003,7 @@ def __call__(
prompt_2,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -1164,15 +1011,9 @@ def __call__(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
- # 3.2 Encode ip_adapter_image
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
-
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
image = self.prepare_image(
@@ -1183,7 +1024,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image.shape[-2:]
@@ -1199,7 +1040,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1213,7 +1054,6 @@ def __call__(
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
- self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -1228,14 +1068,6 @@ def __call__(
latents,
)
- # 6.5 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -1256,17 +1088,8 @@ def __call__(
target_size = target_size or (height, width)
add_text_embeds = pooled_prompt_embeds
- if self.text_encoder_2 is None:
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
- else:
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
add_time_ids = self._get_add_time_ids(
- original_size,
- crops_coords_top_left,
- target_size,
- dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
)
if negative_original_size is not None and negative_target_size is not None:
@@ -1275,12 +1098,11 @@ def __call__(
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -1291,23 +1113,16 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- is_unet_compiled = is_compiled_module(self.unet)
- is_controlnet_compiled = is_compiled_module(self.controlnet)
- is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
- # Relevant thread:
- # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
- torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# controlnet(s) inference
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1340,23 +1155,19 @@ def __call__(
return_dict=False,
)
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
- if ip_adapter_image is not None:
- added_cond_kwargs["image_embeds"] = image_embeds
-
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs,
@@ -1364,23 +1175,13 @@ def __call__(
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
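
The hunks above restore the legacy `callback`/`callback_steps` arguments in place of `callback_on_step_end` for `StableDiffusionXLControlNetPipeline`. A minimal sketch of how the restored interface is typically used follows; the checkpoints, prompt, and blank conditioning image are illustrative placeholders, and the exact call signature depends on the diffusers version installed:

```py
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

# Illustrative checkpoints; any SDXL base + SDXL ControlNet pair behaves the same way.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

canny_image = Image.new("RGB", (1024, 1024))  # placeholder conditioning image


def log_step(step: int, timestep: int, latents: torch.FloatTensor) -> None:
    # The legacy callback only observes the latents; it cannot modify them.
    print(f"step={step} timestep={int(timestep)} latents={tuple(latents.shape)}")


image = pipe(
    "aerial view, a futuristic research complex in a bright foggy jungle",
    image=canny_image,
    callback=log_step,  # argument restored by the hunks above
    callback_steps=5,   # run the callback every 5 denoising steps
).images[0]
```
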
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index ba18567b60f7..3375855ba8ee 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -37,7 +37,6 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
- deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -132,20 +131,6 @@
"""
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
class StableDiffusionXLControlNetImg2ImgPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
):
@@ -197,10 +182,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
watermarker will be used.
"""
-
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
- _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["tokenizer", "text_encoder"]
def __init__(
self,
@@ -346,17 +329,12 @@ def encode_prompt(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
- if self.text_encoder is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
- else:
- scale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
- else:
- scale_lora_layers(self.text_encoder_2, lora_scale)
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -472,11 +450,7 @@ def encode_prompt(
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
- if self.text_encoder_2 is not None:
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -485,12 +459,7 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
-
- if self.text_encoder_2 is not None:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -502,15 +471,10 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)
- if self.text_encoder is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder_2, lora_scale)
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
+ unscale_lora_layers(self.text_encoder_2)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@@ -549,7 +513,6 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
- callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -560,20 +523,14 @@ def check_inputs(
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
f" {type(num_inference_steps)}."
)
-
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -830,12 +787,11 @@ def prepare_latents(
elif isinstance(generator, list):
init_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
- for i in range(batch_size)
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
- init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
@@ -876,7 +832,6 @@ def _get_add_time_ids(
negative_crops_coords_top_left,
negative_target_size,
dtype,
- text_encoder_projection_dim=None,
):
if self.config.requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
@@ -888,7 +843,7 @@ def _get_add_time_ids(
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
passed_add_embed_dim = (
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
@@ -964,29 +919,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1012,6 +944,8 @@ def __call__(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
guess_mode: bool = False,
@@ -1026,9 +960,6 @@ def __call__(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1114,6 +1045,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -1169,15 +1106,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
-            A function that is called at the end of each denoising step during inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-            `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -1186,23 +1114,6 @@ def __call__(
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
containing the output images.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
-
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1212,10 +1123,9 @@ def __call__(
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
- control_guidance_start, control_guidance_end = (
- mult * [control_guidance_start],
- mult * [control_guidance_end],
- )
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
+ control_guidance_end
+ ]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -1234,13 +1144,8 @@ def __call__(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1250,6 +1155,10 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1263,7 +1172,7 @@ def __call__(
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
(
prompt_embeds,
@@ -1275,7 +1184,7 @@ def __call__(
prompt_2,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -1283,7 +1192,7 @@ def __call__(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# 4. Prepare image and controlnet_conditioning_image
@@ -1298,7 +1207,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
@@ -1314,7 +1223,7 @@ def __call__(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1329,7 +1238,6 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
latents = self.prepare_latents(
@@ -1367,12 +1275,6 @@ def __call__(
if negative_target_size is None:
negative_target_size = target_size
add_text_embeds = pooled_prompt_embeds
-
- if self.text_encoder_2 is None:
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
- else:
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
@@ -1383,11 +1285,10 @@ def __call__(
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
@@ -1402,13 +1303,13 @@ def __call__(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# controlnet(s) inference
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1441,7 +1342,7 @@ def __call__(
return_dict=False,
)
- if guess_mode and self.do_classifier_free_guidance:
+ if guess_mode and do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
@@ -1453,7 +1354,7 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs,
@@ -1461,23 +1362,13 @@ def __call__(
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -1515,8 +1406,9 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type)
- # Offload all models
- self.maybe_free_model_hooks()
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
if not return_dict:
return (image,)
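
The `prepare_latents` hunk in this file swaps the removed `retrieve_latents` helper for a direct `latent_dist.sample(...)` call. For a standard KL autoencoder the two paths produce the same latents; the sketch below checks this with a small randomly initialised `AutoencoderKL` (the tiny configuration and tensor shapes are arbitrary choices for demonstration, not values from the pipeline):

```py
import torch
from diffusers import AutoencoderKL

# A randomly initialised, deliberately tiny VAE is enough to compare the two code paths.
vae = AutoencoderKL(
    block_out_channels=(32, 64),
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    latent_channels=4,
)
image = torch.randn(1, 3, 64, 64)

# Path restored by the revert: sample directly from the encoder's posterior.
direct = vae.encode(image).latent_dist.sample(torch.Generator().manual_seed(0))


# Behaviour of the removed helper (its default "sample" mode), reproduced from the hunk above.
def retrieve_latents_sketch(encoder_output, generator=None):
    if hasattr(encoder_output, "latent_dist"):
        return encoder_output.latent_dist.sample(generator)
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")


via_helper = retrieve_latents_sketch(vae.encode(image), torch.Generator().manual_seed(0))
print(torch.allclose(direct, via_helper))  # True: identical for a KL VAE
```
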
diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
index 36cb2c1dcca1..58326d5df471 100644
--- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
+++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
@@ -39,7 +39,6 @@ class DanceDiffusionPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
[`IPNDMScheduler`].
"""
-
model_cpu_offload_seq = "unet"
def __init__(self, unet, scheduler):
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 17d5b7a8c1c7..527e3f04c0f4 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -35,7 +35,6 @@ class DDIMPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
-
model_cpu_offload_seq = "unet"
def __init__(self, unet, scheduler):
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index ef916445ce0c..a07988fca842 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -35,7 +35,6 @@ class DDPMPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
-
model_cpu_offload_seq = "unet"
def __init__(self, unet, scheduler):
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
index 64806d783d51..aaf41529ce6d 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
@@ -98,19 +98,7 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
watermarker: Optional[IFWatermarker]
bad_punct_regex = re.compile(
- r"["
- + "#®•©™&@·º½¾¿¡§~"
- + r"\)"
- + r"\("
- + r"\]"
- + r"\["
- + r"\}"
- + r"\{"
- + r"\|"
- + "\\"
- + r"\/"
- + r"\*"
- + r"]{1,}"
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
@@ -173,11 +161,11 @@ def remove_all_hooks(self):
@torch.no_grad()
def encode_prompt(
self,
- prompt: Union[str, List[str]],
- do_classifier_free_guidance: bool = True,
- num_images_per_prompt: int = 1,
- device: Optional[torch.device] = None,
- negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
@@ -186,14 +174,14 @@ def encode_prompt(
Encodes the prompt into text encoder hidden states.
Args:
- prompt (`str` or `List[str]`, *optional*):
+ prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
- do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
- whether to use classifier free guidance or not
- num_images_per_prompt (`int`, *optional*, defaults to 1):
- number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
@@ -205,8 +193,6 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- clean_caption (bool, defaults to `False`):
- If `True`, the function will preprocess and clean the provided caption before encoding.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
@@ -582,13 +568,13 @@ def __call__(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
- num_inference_steps (`int`, *optional*, defaults to 100):
+ num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 7.0):
+ guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
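
The `guidance_scale` docstrings touched in this file all refer to the weight `w` of equation (2) in the Imagen paper, which the denoising loops shown earlier apply as `eps_uncond + w * (eps_text - eps_uncond)`. A toy numerical sketch of that combination; the tensors are made-up stand-ins for the two halves of a batched UNet prediction:

```py
import torch

guidance_scale = 7.5  # the weight w; w = 1 corresponds to no classifier-free guidance

noise_pred_uncond = torch.randn(1, 4, 8, 8)  # unconditional half of the prediction
noise_pred_text = torch.randn(1, 4, 8, 8)    # text-conditional half

# Classifier-free guidance exactly as in the pipelines' denoising loops.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# With w = 1 the guided prediction collapses back to the text-conditional one.
assert torch.allclose(
    noise_pred_uncond + 1.0 * (noise_pred_text - noise_pred_uncond), noise_pred_text
)
```
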
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
index 6ec4ce6f11f9..98654375efb8 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
@@ -122,19 +122,7 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
watermarker: Optional[IFWatermarker]
bad_punct_regex = re.compile(
- r"["
- + "#®•©™&@·º½¾¿¡§~"
- + r"\)"
- + r"\("
- + r"\]"
- + r"\["
- + r"\}"
- + r"\{"
- + r"\|"
- + "\\"
- + r"\/"
- + r"\*"
- + r"]{1,}"
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
@@ -196,13 +184,14 @@ def remove_all_hooks(self):
self.final_offload_hook = None
@torch.no_grad()
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
def encode_prompt(
self,
- prompt: Union[str, List[str]],
- do_classifier_free_guidance: bool = True,
- num_images_per_prompt: int = 1,
- device: Optional[torch.device] = None,
- negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
@@ -211,14 +200,14 @@ def encode_prompt(
Encodes the prompt into text encoder hidden states.
Args:
- prompt (`str` or `List[str]`, *optional*):
+ prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
- do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
- whether to use classifier free guidance or not
- num_images_per_prompt (`int`, *optional*, defaults to 1):
- number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
@@ -230,8 +219,6 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- clean_caption (bool, defaults to `False`):
- If `True`, the function will preprocess and clean the provided caption before encoding.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
@@ -699,19 +686,19 @@ def __call__(
image (`torch.FloatTensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
- strength (`float`, *optional*, defaults to 0.7):
+ strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
- num_inference_steps (`int`, *optional*, defaults to 80):
+ num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 10.0):
+ guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
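
The `strength` docstring restored above explains that the parameter controls how much of the schedule actually runs in img2img mode. The sketch below mirrors the common `get_timesteps` pattern used across diffusers img2img pipelines; the helper name and rounding details are illustrative rather than copied from this file:

```py
# Rough sketch: strength trims the front of the schedule, so fewer steps run
# for small strengths and the full schedule runs when strength == 1.0.
def effective_steps(num_inference_steps: int, strength: float) -> int:
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start  # steps that are actually executed


for strength in (0.25, 0.5, 1.0):
    print(strength, effective_steps(num_inference_steps=50, strength=strength))
# 0.25 -> 12 steps, 0.5 -> 25 steps, 1.0 -> 50 steps (the input image is fully re-noised)
```
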
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
index d59c2b533dc1..7ee8168e3f61 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
@@ -126,19 +126,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
watermarker: Optional[IFWatermarker]
bad_punct_regex = re.compile(
- r"["
- + "#®•©™&@·º½¾¿¡§~"
- + r"\)"
- + r"\("
- + r"\]"
- + r"\["
- + r"\}"
- + r"\{"
- + r"\|"
- + "\\"
- + r"\/"
- + r"\*"
- + r"]{1,}"
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"]
@@ -350,11 +338,11 @@ def _clean_caption(self, caption):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
def encode_prompt(
self,
- prompt: Union[str, List[str]],
- do_classifier_free_guidance: bool = True,
- num_images_per_prompt: int = 1,
- device: Optional[torch.device] = None,
- negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
@@ -363,14 +351,14 @@ def encode_prompt(
Encodes the prompt into text encoder hidden states.
Args:
- prompt (`str` or `List[str]`, *optional*):
+ prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
- do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
- whether to use classifier free guidance or not
- num_images_per_prompt (`int`, *optional*, defaults to 1):
- number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
@@ -382,8 +370,6 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- clean_caption (bool, defaults to `False`):
- If `True`, the function will preprocess and clean the provided caption before encoding.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
@@ -798,7 +784,7 @@ def __call__(
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 4.0):
+ guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
index 1dbb5e92ec4c..cd867d065ec2 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
@@ -125,19 +125,7 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
watermarker: Optional[IFWatermarker]
bad_punct_regex = re.compile(
- r"["
- + "#®•©™&@·º½¾¿¡§~"
- + r"\)"
- + r"\("
- + r"\]"
- + r"\["
- + r"\}"
- + r"\{"
- + r"\|"
- + "\\"
- + r"\/"
- + r"\*"
- + r"]{1,}"
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
@@ -202,11 +190,11 @@ def remove_all_hooks(self):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
def encode_prompt(
self,
- prompt: Union[str, List[str]],
- do_classifier_free_guidance: bool = True,
- num_images_per_prompt: int = 1,
- device: Optional[torch.device] = None,
- negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
@@ -215,14 +203,14 @@ def encode_prompt(
Encodes the prompt into text encoder hidden states.
Args:
- prompt (`str` or `List[str]`, *optional*):
+ prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
- do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
- whether to use classifier free guidance or not
- num_images_per_prompt (`int`, *optional*, defaults to 1):
- number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
@@ -234,8 +222,6 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- clean_caption (bool, defaults to `False`):
- If `True`, the function will preprocess and clean the provided caption before encoding.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
@@ -800,7 +786,7 @@ def __call__(
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
- strength (`float`, *optional*, defaults to 1.0):
+ strength (`float`, *optional*, defaults to 0.8):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
@@ -812,7 +798,7 @@ def __call__(
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 7.0):
+ guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
index cb9200cffce5..31e0baab6cbe 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -128,19 +128,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
watermarker: Optional[IFWatermarker]
bad_punct_regex = re.compile(
- r"["
- + "#®•©™&@·º½¾¿¡§~"
- + r"\)"
- + r"\("
- + r"\]"
- + r"\["
- + r"\}"
- + r"\{"
- + r"\|"
- + "\\"
- + r"\/"
- + r"\*"
- + r"]{1,}"
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
model_cpu_offload_seq = "text_encoder->unet"
@@ -352,11 +340,11 @@ def _clean_caption(self, caption):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
def encode_prompt(
self,
- prompt: Union[str, List[str]],
- do_classifier_free_guidance: bool = True,
- num_images_per_prompt: int = 1,
- device: Optional[torch.device] = None,
- negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
@@ -365,14 +353,14 @@ def encode_prompt(
Encodes the prompt into text encoder hidden states.
Args:
- prompt (`str` or `List[str]`, *optional*):
+ prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
- do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
- whether to use classifier free guidance or not
- num_images_per_prompt (`int`, *optional*, defaults to 1):
- number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
@@ -384,8 +372,6 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- clean_caption (bool, defaults to `False`):
- If `True`, the function will preprocess and clean the provided caption before encoding.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
@@ -888,13 +874,13 @@ def __call__(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
- num_inference_steps (`int`, *optional*, defaults to 100):
+ num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 4.0):
+ guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
@@ -1121,6 +1107,8 @@ def __call__(
nsfw_detected = None
watermark_detected = None
+ if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
+ self.unet_offload_hook.offload()
else:
# 10. Post-processing
image = (image / 2 + 0.5).clamp(0, 1)
@@ -1129,7 +1117,9 @@ def __call__(
# 11. Run safety checker
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
- self.maybe_free_model_hooks()
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
if not return_dict:
return (image, nsfw_detected, watermark_detected)
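
The offloading hunks in this file trade `maybe_free_model_hooks()` for the older guarded `final_offload_hook.offload()` pattern. Stripped of diffusers specifics, the restored logic is simply "offload the last model only if CPU offloading was ever enabled"; the stubs below sketch that control flow and are not the accelerate implementation:

```py
class OffloadHookStub:
    """Stand-in for the hook that enable_model_cpu_offload() would attach."""

    def __init__(self, name: str):
        self.name = name

    def offload(self) -> None:
        print(f"moving {self.name} back to CPU")


class PipelineSketch:
    def __init__(self, cpu_offload_enabled: bool):
        # In diffusers the hook is created by enable_model_cpu_offload();
        # here its presence is simply simulated.
        self.final_offload_hook = OffloadHookStub("unet") if cpu_offload_enabled else None

    def __call__(self) -> str:
        # ... denoising and decoding would happen here ...
        # Pattern restored by the revert: offload only when the hook exists.
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()
        return "image"


PipelineSketch(cpu_offload_enabled=True)()   # prints the offload message
PipelineSketch(cpu_offload_enabled=False)()  # silently skips offloading
```
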
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
index 2b48f5887c29..6e89df15156f 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
@@ -84,19 +84,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
watermarker: Optional[IFWatermarker]
bad_punct_regex = re.compile(
- r"["
- + "#®•©™&@·º½¾¿¡§~"
- + r"\)"
- + r"\("
- + r"\]"
- + r"\["
- + r"\}"
- + r"\{"
- + r"\|"
- + "\\"
- + r"\/"
- + r"\*"
- + r"]{1,}"
+ r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
@@ -308,11 +296,11 @@ def _clean_caption(self, caption):
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
def encode_prompt(
self,
- prompt: Union[str, List[str]],
- do_classifier_free_guidance: bool = True,
- num_images_per_prompt: int = 1,
- device: Optional[torch.device] = None,
- negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt,
+ do_classifier_free_guidance=True,
+ num_images_per_prompt=1,
+ device=None,
+ negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
clean_caption: bool = False,
@@ -321,14 +309,14 @@ def encode_prompt(
Encodes the prompt into text encoder hidden states.
Args:
- prompt (`str` or `List[str]`, *optional*):
+ prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
- do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
- whether to use classifier free guidance or not
- num_images_per_prompt (`int`, *optional*, defaults to 1):
- number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead.
@@ -340,8 +328,6 @@ def encode_prompt(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
- clean_caption (bool, defaults to `False`):
- If `True`, the function will preprocess and clean the provided caption before encoding.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
@@ -651,19 +637,19 @@ def __call__(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
- height (`int`, *optional*, defaults to None):
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
- width (`int`, *optional*, defaults to None):
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
The image to be upscaled.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*, defaults to None):
+ timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
- guidance_scale (`float`, *optional*, defaults to 4.0):
+ guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index e5eed8c0c1da..022aa1202603 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -43,7 +43,6 @@ class DiTPipeline(DiffusionPipeline):
scheduler ([`DDIMScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
-
model_cpu_offload_seq = "transformer->vae"
def __init__(
@@ -167,6 +166,7 @@ def __call__(
# set step values
self.scheduler.set_timesteps(num_inference_steps)
+
for t in self.progress_bar(self.scheduler.timesteps):
if guidance_scale > 1:
half = latent_model_input[: len(latent_model_input) // 2]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index 5e7a69e756ce..5c78b0dce87e 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -388,8 +388,6 @@ def __call__(
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
- self.maybe_free_model_hooks()
-
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index eff8af4c723e..25508e1e080f 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -321,9 +321,6 @@ def __call__(
callback_steps=callback_steps,
return_dict=return_dict,
)
-
- self.maybe_free_model_hooks()
-
return outputs
@@ -561,9 +558,6 @@ def __call__(
callback_steps=callback_steps,
return_dict=return_dict,
)
-
- self.maybe_free_model_hooks()
-
return outputs
@@ -599,7 +593,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
"""
_load_connected_pipes = True
- model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
+ model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
def __init__(
self,
@@ -808,7 +802,4 @@ def __call__(
callback_steps=callback_steps,
return_dict=return_dict,
)
-
- self.maybe_free_model_hooks()
-
return outputs
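
One hunk above restores `model_cpu_offload_seq` for `KandinskyInpaintCombinedPipeline` as two adjacent string literals. Python concatenates adjacent literals at compile time, so the restored value is identical to the single-string form; a quick sanity check:

```py
combined = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
single = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"

assert combined == single
print(combined.split("->"))
# ['prior_text_encoder', 'prior_image_encoder', 'prior_prior', 'text_encoder', 'unet', 'movq']
```
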
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
index c5e7af270906..a22823aadef4 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -481,8 +481,6 @@ def __call__(
# 7. post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
- self.maybe_free_model_hooks()
-
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index e9b5eb5cdd70..144e3ce585af 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -616,8 +616,6 @@ def __call__(
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
- self.maybe_free_model_hooks()
-
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
index a9c12b258974..c9a6019a8eac 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
@@ -527,7 +527,7 @@ def __call__(
if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
- self.maybe_free_model_hooks()
+            self.maybe_free_model_hooks()
else:
image_embeddings, zero_embeds = image_embeddings.chunk(2)
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
index d87aa9ff2d19..3d7b09471969 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
import torch
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import (
+ logging,
+ replace_example_docstring,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -78,7 +81,6 @@ class KandinskyV22Pipeline(DiffusionPipeline):
"""
model_cpu_offload_seq = "unet->movq"
- _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"]
def __init__(
self,
@@ -107,18 +109,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
latents = latents * scheduler.init_noise_sigma
return latents
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -133,10 +123,9 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
return_dict: bool = True,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
"""
Function invoked when calling the pipeline for generation.
@@ -171,50 +160,23 @@ def __call__(
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
device = self._execution_device
- self._guidance_scale = guidance_scale
+ do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0)
@@ -222,7 +184,7 @@ def __call__(
if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
@@ -231,7 +193,7 @@ def __call__(
)
self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps = self.scheduler.timesteps
+ timesteps_tensor = self.scheduler.timesteps
num_channels_latents = self.unet.config.in_channels
@@ -247,10 +209,9 @@ def __call__(
self.scheduler,
)
- self._num_timesteps = len(timesteps)
- for i, t in enumerate(self.progress_bar(timesteps)):
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
added_cond_kwargs = {"image_embeds": image_embeds}
noise_pred = self.unet(
@@ -261,11 +222,11 @@ def __call__(
return_dict=False,
)[0]
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not (
@@ -282,37 +243,24 @@ def __call__(
generator=generator,
)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- image_embeds = callback_outputs.pop("image_embeds", image_embeds)
- negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds)
-
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ # post-processing
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
- if output_type not in ["pt", "np", "pil", "latent"]:
- raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
+ self.maybe_free_model_hooks()
- if not output_type == "latent":
- # post-processing
- image = self.movq.decode(latents, force_not_quantize=True)["sample"]
- if output_type in ["np", "pil"]:
- image = image * 0.5 + 0.5
- image = image.clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
- if output_type == "pil":
- image = self.numpy_to_pil(image)
- else:
- image = latents
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
- self.maybe_free_model_hooks()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
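With the signature restored above, the decoder pipeline again takes the step-wise `callback` / `callback_steps` arguments. A hedged sketch of driving it with embeddings from the prior (checkpoints and prompt are illustrative):

```py
import torch
from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline

prior = KandinskyV22PriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16
).to("cuda")
decoder = KandinskyV22Pipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
).to("cuda")

image_embeds, negative_image_embeds = prior("a red panda reading a book").to_tuple()

def log_step(step: int, timestep: int, latents: torch.FloatTensor):
    # Invoked every `callback_steps` denoising steps with the current latents.
    print(f"step {step}: latents {tuple(latents.shape)}")

image = decoder(
    image_embeds=image_embeds,
    negative_image_embeds=negative_image_embeds,
    callback=log_step,
    callback_steps=5,
).images[0]
```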
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
index 2b8a49976fc9..07c242c9fca7 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
import PIL.Image
import torch
@@ -20,7 +20,10 @@
from ...models import PriorTransformer, UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler, UnCLIPScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import (
+ logging,
+ replace_example_docstring,
+)
from ..pipeline_utils import DiffusionPipeline
from .pipeline_kandinsky2_2 import KandinskyV22Pipeline
from .pipeline_kandinsky2_2_img2img import KandinskyV22Img2ImgPipeline
@@ -217,10 +220,6 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
return_dict: bool = True,
- prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
"""
Function invoked when calling the pipeline for generation.
@@ -265,25 +264,14 @@ def __call__(
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
- prior_callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference of the prior pipeline.
- The function is called with the following arguments: `prior_callback_on_step_end(self:
- DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
- prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
- list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
- the `._callback_tensor_inputs` attribute of your prior pipeline class.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference of the decoder pipeline.
- The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline,
- step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors
- as specified by `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -300,8 +288,6 @@ def __call__(
guidance_scale=prior_guidance_scale,
output_type="pt",
return_dict=False,
- callback_on_step_end=prior_callback_on_step_end,
- callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
)
image_embeds = prior_outputs[0]
negative_image_embeds = prior_outputs[1]
@@ -323,11 +309,7 @@ def __call__(
callback=callback,
callback_steps=callback_steps,
return_dict=return_dict,
- callback_on_step_end=callback_on_step_end,
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
- self.maybe_free_model_hooks()
-
return outputs
@@ -456,10 +438,6 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
return_dict: bool = True,
- prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
"""
Function invoked when calling the pipeline for generation.
@@ -538,8 +516,6 @@ def __call__(
guidance_scale=prior_guidance_scale,
output_type="pt",
return_dict=False,
- callback_on_step_end=prior_callback_on_step_end,
- callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
)
image_embeds = prior_outputs[0]
negative_image_embeds = prior_outputs[1]
@@ -571,11 +547,7 @@ def __call__(
callback=callback,
callback_steps=callback_steps,
return_dict=return_dict,
- callback_on_step_end=callback_on_step_end,
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
-
- self.maybe_free_model_hooks()
return outputs
@@ -691,12 +663,9 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
return_dict: bool = True,
- prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
"""
Function invoked when calling the pipeline for generation.
@@ -750,48 +719,20 @@ def __call__(
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
- prior_callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
- int, callback_kwargs: Dict)`.
- prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
- list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
- the `._callback_tensor_inputs` attribute of your pipeline class.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
-
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
- prior_kwargs = {}
- if kwargs.get("prior_callback", None) is not None:
- prior_kwargs["callback"] = kwargs.pop("prior_callback")
- deprecate(
- "prior_callback",
- "1.0.0",
- "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
- )
- if kwargs.get("prior_callback_steps", None) is not None:
- deprecate(
- "prior_callback_steps",
- "1.0.0",
- "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
- )
- prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps")
-
prior_outputs = self.prior_pipe(
prompt=prompt,
negative_prompt=negative_prompt,
@@ -802,9 +743,6 @@ def __call__(
guidance_scale=prior_guidance_scale,
output_type="pt",
return_dict=False,
- callback_on_step_end=prior_callback_on_step_end,
- callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
- **prior_kwargs,
)
image_embeds = prior_outputs[0]
negative_image_embeds = prior_outputs[1]
@@ -841,11 +779,8 @@ def __call__(
generator=generator,
guidance_scale=guidance_scale,
output_type=output_type,
+ callback=callback,
+ callback_steps=callback_steps,
return_dict=return_dict,
- callback_on_step_end=callback_on_step_end,
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
- **kwargs,
)
- self.maybe_free_model_hooks()
-
return outputs
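The combined text-to-image pipeline above now forwards `callback` / `callback_steps` straight to the decoder stage. A hedged end-to-end sketch (the checkpoint id is the community release; resolving it through `AutoPipelineForText2Image` is an assumption based on current diffusers behavior):

```py
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="aerial view, a futuristic research complex in a bright foggy jungle",
    callback=lambda step, t, latents: print(f"decoder step {step}"),
    callback_steps=10,
).images[0]
```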
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
index 92343e2667e6..8cf3735672a8 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
@@ -21,7 +21,9 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import (
+ logging,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -106,7 +108,6 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
"""
model_cpu_offload_seq = "unet->movq"
- _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds"]
def __init__(
self,
@@ -175,18 +176,6 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
return latents
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
def __call__(
self,
@@ -201,10 +190,9 @@ def __call__(
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
return_dict: bool = True,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
"""
Function invoked when calling the pipeline for generation.
@@ -245,50 +233,23 @@ def __call__(
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
device = self._execution_device
- self._guidance_scale = guidance_scale
+ do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0)
@@ -296,7 +257,7 @@ def __call__(
if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
@@ -323,10 +284,9 @@ def __call__(
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
)
- self._num_timesteps = len(timesteps)
for i, t in enumerate(self.progress_bar(timesteps)):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
added_cond_kwargs = {"image_embeds": image_embeds}
noise_pred = self.unet(
@@ -337,11 +297,11 @@ def __call__(
return_dict=False,
)[0]
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not (
@@ -358,41 +318,27 @@ def __call__(
generator=generator,
)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- image_embeds = callback_outputs.pop("image_embeds", image_embeds)
- negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds)
-
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
- if output_type not in ["pt", "np", "pil", "latent"]:
- raise ValueError(
- f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
- )
-
- if not output_type == "latent":
- # post-processing
- image = self.movq.decode(latents, force_not_quantize=True)["sample"]
- if output_type in ["np", "pil"]:
- image = image * 0.5 + 0.5
- image = image.clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- if output_type == "pil":
- image = self.numpy_to_pil(image)
- else:
- image = latents
+ # post-processing
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
# Offload all models
self.maybe_free_model_hooks()
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
+
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
if not return_dict:
return (image,)
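The guidance branch restored in the loop above splits the UNet output into a noise prediction and a learned variance, applies classifier-free guidance with the plain `guidance_scale` local, and re-attaches the text-conditioned variance. A standalone sketch of that arithmetic:

```py
import torch

def apply_kandinsky_cfg(noise_pred: torch.Tensor, latent_channels: int, guidance_scale: float) -> torch.Tensor:
    # Split off the learned variance channels, then the unconditional / text-conditioned halves.
    noise_pred, variance_pred = noise_pred.split(latent_channels, dim=1)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    _, variance_pred_text = variance_pred.chunk(2)
    # Classifier-free guidance: move the prediction toward the text-conditioned estimate.
    guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    return torch.cat([guided, variance_pred_text], dim=1)
```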
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
index 66e62303f3f6..7a9326b708e5 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
@@ -13,7 +13,7 @@
# limitations under the License.
from copy import deepcopy
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
@@ -25,7 +25,9 @@
from ... import __version__
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import (
+ logging,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -249,7 +251,6 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
"""
model_cpu_offload_seq = "unet->movq"
- _callback_tensor_inputs = ["latents", "image_embeds", "negative_image_embeds", "masked_image", "mask_image"]
def __init__(
self,
@@ -279,18 +280,6 @@ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
latents = latents * scheduler.init_noise_sigma
return latents
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
def __call__(
self,
@@ -306,10 +295,9 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
return_dict: bool = True,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
"""
Function invoked when calling the pipeline for generation.
@@ -352,17 +340,14 @@ def __call__(
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -382,40 +367,17 @@ def __call__(
)
self._warn_has_been_called = True
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
- self._guidance_scale = guidance_scale
-
device = self._execution_device
+ do_classifier_free_guidance = guidance_scale > 1.0
+
if isinstance(image_embeds, list):
image_embeds = torch.cat(image_embeds, dim=0)
batch_size = image_embeds.shape[0] * num_images_per_prompt
if isinstance(negative_image_embeds, list):
negative_image_embeds = torch.cat(negative_image_embeds, dim=0)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
@@ -424,7 +386,7 @@ def __call__(
)
self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps = self.scheduler.timesteps
+ timesteps_tensor = self.scheduler.timesteps
# preprocess image and mask
mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width)
@@ -445,7 +407,7 @@ def __call__(
mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0)
masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
mask_image = mask_image.repeat(2, 1, 1, 1)
masked_image = masked_image.repeat(2, 1, 1, 1)
@@ -463,11 +425,9 @@ def __call__(
self.scheduler,
)
noise = torch.clone(latents)
-
- self._num_timesteps = len(timesteps)
- for i, t in enumerate(self.progress_bar(timesteps)):
+ for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1)
added_cond_kwargs = {"image_embeds": image_embeds}
@@ -479,11 +439,11 @@ def __call__(
return_dict=False,
)[0]
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
_, variance_pred_text = variance_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1)
if not (
@@ -502,53 +462,35 @@ def __call__(
init_latents_proper = image[:1]
init_mask = mask_image[:1]
- if i < len(timesteps) - 1:
- noise_timestep = timesteps[i + 1]
+ if i < len(timesteps_tensor) - 1:
+ noise_timestep = timesteps_tensor[i + 1]
init_latents_proper = self.scheduler.add_noise(
init_latents_proper, noise, torch.tensor([noise_timestep])
)
latents = init_mask * init_latents_proper + (1 - init_mask) * latents
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- image_embeds = callback_outputs.pop("image_embeds", image_embeds)
- negative_image_embeds = callback_outputs.pop("negative_image_embeds", negative_image_embeds)
- masked_image = callback_outputs.pop("masked_image", masked_image)
- mask_image = callback_outputs.pop("mask_image", mask_image)
-
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
+ image = self.movq.decode(latents, force_not_quantize=True)["sample"]
- if output_type not in ["pt", "np", "pil", "latent"]:
- raise ValueError(
- f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
- )
-
- if not output_type == "latent":
- image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+ # Offload all models
+ self.maybe_free_model_hooks()
- if output_type in ["np", "pil"]:
- image = image * 0.5 + 0.5
- image = image.clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
- if output_type == "pil":
- image = self.numpy_to_pil(image)
- else:
- image = latents
+ if output_type in ["np", "pil"]:
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
- # Offload all models
- self.maybe_free_model_hooks()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
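The per-step update restored above re-noises the original image latents to the next timestep and copies them back into the region selected by the mask, so only the complementary region keeps evolving. A minimal sketch of that blend (tensor shapes are assumed to broadcast as in the pipeline):

```py
import torch

def blend_known_region(latents: torch.Tensor, init_latents_proper: torch.Tensor, init_mask: torch.Tensor) -> torch.Tensor:
    # Where init_mask is 1, take the re-noised original latents; elsewhere keep the
    # denoised latents produced by the current step.
    return init_mask * init_latents_proper + (1 - init_mask) * latents
```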
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
index 83427c68f208..2bdd049d1c24 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Optional, Union
+from typing import List, Optional, Union
import PIL.Image
import torch
@@ -106,7 +106,6 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->image_encoder->prior"
_exclude_from_cpu_offload = ["prior"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "text_encoder_hidden_states", "text_mask"]
def __init__(
self,
@@ -355,18 +354,6 @@ def _encode_prompt(
return prompt_embeds, text_encoder_hidden_states, text_mask
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -380,8 +367,6 @@ def __call__(
guidance_scale: float = 4.0,
output_type: Optional[str] = "pt", # pt only
return_dict: bool = True,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
"""
Function invoked when calling the pipeline for generation.
@@ -415,15 +400,6 @@ def __call__(
(`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -431,13 +407,6 @@ def __call__(
[`KandinskyPriorPipelineOutput`] or `tuple`
"""
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if isinstance(prompt, str):
prompt = [prompt]
elif not isinstance(prompt, list):
@@ -459,15 +428,14 @@ def __call__(
batch_size = len(prompt)
batch_size = batch_size * num_images_per_prompt
- self._guidance_scale = guidance_scale
-
+ do_classifier_free_guidance = guidance_scale > 1.0
prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
- prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
# prior
self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps = self.scheduler.timesteps
+ prior_timesteps_tensor = self.scheduler.timesteps
embedding_dim = self.prior.config.embedding_dim
@@ -479,10 +447,10 @@ def __call__(
latents,
self.scheduler,
)
- self._num_timesteps = len(timesteps)
- for i, t in enumerate(self.progress_bar(timesteps)):
+
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
predicted_image_embedding = self.prior(
latent_model_input,
@@ -492,16 +460,16 @@ def __call__(
attention_mask=text_mask,
).predicted_image_embedding
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
- predicted_image_embedding = predicted_image_embedding_uncond + self.guidance_scale * (
+ predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * (
predicted_image_embedding_text - predicted_image_embedding_uncond
)
- if i + 1 == timesteps.shape[0]:
+ if i + 1 == prior_timesteps_tensor.shape[0]:
prev_timestep = None
else:
- prev_timestep = timesteps[i + 1]
+ prev_timestep = prior_timesteps_tensor[i + 1]
latents = self.scheduler.step(
predicted_image_embedding,
@@ -511,19 +479,6 @@ def __call__(
prev_timestep=prev_timestep,
).prev_sample
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- text_encoder_hidden_states = callback_outputs.pop(
- "text_encoder_hidden_states", text_encoder_hidden_states
- )
- text_mask = callback_outputs.pop("text_mask", text_mask)
-
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
@@ -531,10 +486,14 @@ def __call__(
# if negative prompt has been defined, we retrieve split the image embedding into two
if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
+
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
else:
image_embeddings, zero_embeds = image_embeddings.chunk(2)
- self.maybe_free_model_hooks()
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.prior_hook.offload()
if output_type not in ["pt", "np"]:
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
index bef70821c605..b4a6a64137ec 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
@@ -545,10 +545,12 @@ def __call__(
# if negative prompt has been defined, we retrieve split the image embedding into two
if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
else:
image_embeddings, zero_embeds = image_embeddings.chunk(2)
-
- self.maybe_free_model_hooks()
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.prior_hook.offload()
if output_type not in ["pt", "np"]:
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index 99b9c9f65f82..cedf9de01475 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -49,7 +49,6 @@ class LDMTextToImagePipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
-
model_cpu_offload_seq = "bert->unet->vqvae"
def __init__(
diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
index 68af3925fa02..9e6b6fea13e5 100644
--- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
+++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
@@ -51,7 +51,7 @@
>>> import torch
>>> import scipy
- >>> repo_id = "ucsd-reach/musicldm"
+ >>> repo_id = "cvssp/audioldm-s-full-v2"
>>> pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
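The docstring hunk above only swaps the example `repo_id`. For reference, a hedged sketch of the surrounding usage pattern, written against the released MusicLDM checkpoint named in the removed line (the 16 kHz sampling rate is an assumption based on AudioLDM-style vocoders):

```py
import torch
from scipy.io import wavfile
from diffusers import MusicLDMPipeline

pipe = MusicLDMPipeline.from_pretrained("ucsd-reach/musicldm", torch_dtype=torch.float16).to("cuda")

audio = pipe(
    "techno music with a strong, upbeat tempo",
    num_inference_steps=200,
    audio_length_in_s=10.0,
).audios[0]

# Assumed 16 kHz output; adjust if the checkpoint's vocoder uses a different rate.
wavfile.write("techno.wav", rate=16000, data=audio)
```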
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 0a20981beb05..a782caa55efc 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -34,20 +34,6 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
def prepare_mask_and_masked_image(image, mask):
"""
Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be
@@ -181,7 +167,6 @@ class PaintByExamplePipeline(DiffusionPipeline):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
# TODO: feature_extractor is required to encode initial images (if they are in PIL format),
# we should give a descriptive message if the pipeline doesn't have one.
@@ -349,12 +334,12 @@ def prepare_mask_latents(
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
image_latents = self.vae.config.scaling_factor * image_latents
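With `retrieve_latents` removed above, image latents come straight from the VAE posterior and are scaled by the configured factor. A hedged standalone sketch of that path (the standalone VAE checkpoint and the random tensor are placeholders):

```py
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
image = torch.randn(1, 3, 512, 512)  # stands in for a preprocessed image in [-1, 1]

with torch.no_grad():
    # Sample from the encoder's latent distribution, then apply the scaling factor.
    latents = vae.encode(image).latent_dist.sample()
latents = vae.config.scaling_factor * latents
```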
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index 2e25a40295b4..7b067405cace 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -112,7 +112,6 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
- **config_name** ([`str`]) -- The configuration filename that stores the class and module names of all the
diffusion pipeline's components.
"""
-
config_name = "model_index.json"
def register_modules(self, **kwargs):
@@ -538,13 +537,12 @@ def load_module(name, value):
model = pipeline_class(**init_kwargs, dtype=dtype)
return model, params
- @classmethod
- def _get_signature_keys(cls, obj):
+ @staticmethod
+ def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
-
return expected_modules, optional_parameters
@property
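The `_get_signature_keys` change above is purely a `classmethod` → `staticmethod` swap; the computation itself stays the same. A small sketch of what it derives from a pipeline's `__init__` signature (the toy class is hypothetical):

```py
import inspect

class ToyPipeline:
    def __init__(self, unet, scheduler, safety_checker=None):
        ...

parameters = inspect.signature(ToyPipeline.__init__).parameters
# Required parameters (minus `self`) become the expected modules; defaulted ones are optional.
required = {k for k, v in parameters.items() if v.default == inspect._empty} - {"self"}
optional = {k for k, v in parameters.items() if v.default != inspect._empty}
print(required, optional)  # expected modules: {'unet', 'scheduler'}; optional: {'safety_checker'}
```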
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index b84344fab85e..bad23a60293f 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -33,6 +33,8 @@
from requests.exceptions import HTTPError
from tqdm.auto import tqdm
+import diffusers
+
from .. import __version__
from ..configuration_utils import ConfigMixin
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
@@ -49,7 +51,6 @@
get_class_from_dynamic_module,
is_accelerate_available,
is_accelerate_version,
- is_peft_available,
is_torch_version,
is_transformers_available,
logging,
@@ -159,9 +160,9 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
continue
if extension == ".bin":
- pt_filenames.append(os.path.normpath(filename))
+ pt_filenames.append(filename)
elif extension == ".safetensors":
- sf_filenames.add(os.path.normpath(filename))
+ sf_filenames.add(filename)
for filename in pt_filenames:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
@@ -173,8 +174,9 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
else:
filename = filename
- expected_sf_filename = os.path.normpath(os.path.join(path, filename))
+ expected_sf_filename = os.path.join(path, filename)
expected_sf_filename = f"{expected_sf_filename}.safetensors"
+
if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found")
return False
@@ -259,7 +261,7 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token,
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
- if set(model_filenames).issubset(set(comp_model_filenames)):
+ if set(comp_model_filenames) == set(model_filenames):
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
@@ -271,20 +273,6 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token,
)
-def _unwrap_model(model):
- """Unwraps a model."""
- if is_compiled_module(model):
- model = model._orig_mod
-
- if is_peft_available():
- from peft import PeftModel
-
- if isinstance(model, PeftModel):
- model = model.base_model.model
-
- return model
-
-
def maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
@@ -302,8 +290,9 @@ def maybe_raise_or_warn(
# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
sub_model = passed_class_obj[name]
- unwrapped_sub_model = _unwrap_model(sub_model)
- model_cls = unwrapped_sub_model.__class__
+ model_cls = sub_model.__class__
+ if is_compiled_module(sub_model):
+ model_cls = sub_model._orig_mod.__class__
if not issubclass(model_cls, expected_class_obj):
raise ValueError(
@@ -316,23 +305,13 @@ def maybe_raise_or_warn(
)
-def get_class_obj_and_candidates(
- library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
-):
+def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
- component_folder = os.path.join(cache_dir, component_name)
-
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
class_candidates = {c: class_obj for c in importable_classes.keys()}
- elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
- # load custom component
- class_obj = get_class_from_dynamic_module(
- component_folder, module_file=library_name + ".py", class_name=class_name
- )
- class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
@@ -344,15 +323,7 @@ def get_class_obj_and_candidates(
def _get_pipeline_class(
- class_obj,
- config,
- load_connected_pipeline=False,
- custom_pipeline=None,
- repo_id=None,
- hub_revision=None,
- class_name=None,
- cache_dir=None,
- revision=None,
+ class_obj, config, load_connected_pipeline=False, custom_pipeline=None, cache_dir=None, revision=None
):
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
@@ -360,24 +331,11 @@ def _get_pipeline_class(
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
- elif repo_id is not None:
- file_name = f"{custom_pipeline}.py"
- custom_pipeline = repo_id
else:
file_name = CUSTOM_PIPELINE_FILE_NAME
- if repo_id is not None and hub_revision is not None:
- # if we load the pipeline code from the Hub
- # make sure to overwrite the `revison`
- revision = hub_revision
-
return get_class_from_dynamic_module(
- custom_pipeline,
- module_file=file_name,
- class_name=class_name,
- repo_id=repo_id,
- cache_dir=cache_dir,
- revision=revision,
+ custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=revision
)
if class_obj != DiffusionPipeline:
@@ -425,18 +383,11 @@ def load_sub_model(
variant: str,
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
- revision: str = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
# retrieve class candidates
class_obj, class_candidates = get_class_obj_and_candidates(
- library_name,
- class_name,
- importable_classes,
- pipelines,
- is_pipeline_module,
- component_name=name,
- cache_dir=cached_folder,
+ library_name, class_name, importable_classes, pipelines, is_pipeline_module
)
load_method_name = None
@@ -463,15 +414,14 @@ def load_sub_model(
load_method = getattr(class_obj, load_method_name)
# add kwargs to loading method
- diffusers_module = importlib.import_module(__name__.split(".")[0])
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
- if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
- is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
+ is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
@@ -542,7 +492,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
- **_optional_components** (`List[str]`) -- List of all optional components that don't have to be passed to the
pipeline to function (should be overridden by subclasses).
"""
-
config_name = "model_index.json"
model_cpu_offload_seq = None
_optional_components = []
@@ -552,16 +501,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
def register_modules(self, **kwargs):
# import it here to avoid circular import
- diffusers_module = importlib.import_module(__name__.split(".")[0])
- pipelines = getattr(diffusers_module, "pipelines")
+ from diffusers import pipelines
for name, module in kwargs.items():
# retrieve library
- if module is None or isinstance(module, (tuple, list)) and module[0] is None:
+ if module is None:
register_dict = {name: (None, None)}
else:
# register the config from the original module, not the dynamo compiled one
- not_compiled_module = _unwrap_model(module)
+ if is_compiled_module(module):
+ not_compiled_module = module._orig_mod
+ else:
+ not_compiled_module = module
library = not_compiled_module.__module__.split(".")[0]
@@ -664,7 +615,7 @@ def is_saveable_module(name, value):
# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
if is_compiled_module(sub_model):
- sub_model = _unwrap_model(sub_model)
+ sub_model = sub_model._orig_mod
model_cls = sub_model.__class__
save_method_name = None
@@ -1129,21 +1080,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
- custom_class_name = None
- if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
- custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
- elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
- os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
- ):
- custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
- custom_class_name = config_dict["_class_name"][1]
-
pipeline_class = _get_pipeline_class(
cls,
config_dict,
load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline,
- class_name=custom_class_name,
cache_dir=cache_dir,
revision=custom_revision,
)
@@ -1282,7 +1223,6 @@ def load_module(name, value):
variant=variant,
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
- revision=revision,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1602,10 +1542,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
`False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending
with `.onnx` and `.pb`.
- trust_remote_code (`bool`, *optional*, defaults to `False`):
- Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
- option should only be set to `True` for repositories you trust and in which you have read the code, as
- it will execute code present on the Hub on your local machine.
Returns:
`os.PathLike`:
@@ -1633,7 +1569,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
- trust_remote_code = kwargs.pop("trust_remote_code", False)
allow_pickle = False
if use_safetensors is None:
@@ -1669,35 +1604,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)
config_dict = cls._dict_from_json_file(config_file)
+
ignore_filenames = config_dict.pop("_ignore_files", [])
# retrieve all folder_names that contain relevant files
- folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
+ folder_names = [k for k, v in config_dict.items() if isinstance(v, list)]
filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
- diffusers_module = importlib.import_module(__name__.split(".")[0])
- pipelines = getattr(diffusers_module, "pipelines")
-
- # optionally create a custom component <> custom file mapping
- custom_components = {}
- for component in folder_names:
- module_candidate = config_dict[component][0]
-
- if module_candidate is None or not isinstance(module_candidate, str):
- continue
-
- # We compute candidate file path on the Hub. Do not use `os.path.join`.
- candidate_file = f"{component}/{module_candidate}.py"
-
- if candidate_file in filenames:
- custom_components[component] = module_candidate
- elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
- raise ValueError(
- f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
- )
-
if len(variant_filenames) == 0 and variant is not None:
deprecation_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
@@ -1721,21 +1636,12 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
- custom_class_name = None
- if custom_pipeline is None and isinstance(config_dict["_class_name"], (list, tuple)):
- custom_pipeline = config_dict["_class_name"][0]
- custom_class_name = config_dict["_class_name"][1]
-
# all filenames compatible with variant will be added
allow_patterns = list(model_filenames)
# allow all patterns from non-model folders
# this enables downloading schedulers, tokenizers, ...
allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names]
- # add custom component files
- allow_patterns += [f"{k}/{f}.py" for k, f in custom_components.items()]
- # add custom pipeline file
- allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
@@ -1746,32 +1652,12 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
CUSTOM_PIPELINE_FILE_NAME,
]
- load_pipe_from_hub = custom_pipeline is not None and f"{custom_pipeline}.py" in filenames
- load_components_from_hub = len(custom_components) > 0
-
- if load_pipe_from_hub and not trust_remote_code:
- raise ValueError(
- f"The repository for {pretrained_model_name} contains custom code in {custom_pipeline}.py which must be executed to correctly "
- f"load the model. You can inspect the repository content at https://hf.co/{pretrained_model_name}/blob/main/{custom_pipeline}.py.\n"
- f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
- )
-
- if load_components_from_hub and not trust_remote_code:
- raise ValueError(
- f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
- f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
- f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
- )
-
# retrieve passed components that should not be downloaded
pipeline_class = _get_pipeline_class(
cls,
config_dict,
load_connected_pipeline=load_connected_pipeline,
custom_pipeline=custom_pipeline,
- repo_id=pretrained_model_name if load_pipe_from_hub else None,
- hub_revision=revision,
- class_name=custom_class_name,
cache_dir=cache_dir,
revision=custom_revision,
)
@@ -1786,7 +1672,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)
):
raise EnvironmentError(
- f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
+ f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
@@ -1868,10 +1754,9 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
# retrieve pipeline class from local file
cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
- cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
+ cls_name = cls_name[4:] if cls_name.startswith("Flax") else cls_name
- diffusers_module = importlib.import_module(__name__.split(".")[0])
- pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
+ pipeline_class = getattr(diffusers, cls_name, None)
if pipeline_class is not None and pipeline_class._load_connected_pipes:
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
@@ -1907,19 +1792,12 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
" above."
) from model_info_call_error
- @classmethod
- def _get_signature_keys(cls, obj):
+ @staticmethod
+ def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
-
- optional_names = list(optional_parameters)
- for name in optional_names:
- if name in cls._optional_components:
- expected_modules.add(name)
- optional_parameters.remove(name)
-
return expected_modules, optional_parameters
@property
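For context on the `_get_signature_keys` hunk above: the method is plain `inspect` introspection of a pipeline's `__init__`. A minimal, self-contained sketch of the same idea, applied to a hypothetical `ToyPipeline` (the class name and its components are illustrative, not part of this diff):

```py
import inspect

class ToyPipeline:
    def __init__(self, unet, scheduler, safety_checker=None, feature_extractor=None):
        self.unet = unet
        self.scheduler = scheduler

def get_signature_keys(obj):
    # Parameters without a default are the modules the pipeline expects;
    # parameters with a default are treated as optional.
    params = inspect.signature(obj.__init__).parameters
    required = {k for k, v in params.items() if v.default is inspect.Parameter.empty}
    optional = {k for k, v in params.items() if v.default is not inspect.Parameter.empty}
    return required - {"self"}, optional

print(get_signature_keys(ToyPipeline))
# ({'unet', 'scheduler'}, {'safety_checker', 'feature_extractor'})
```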
diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
index 11d1af710355..eb98479b9b61 100644
--- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -35,7 +35,6 @@ class ScoreSdeVePipeline(DiffusionPipeline):
scheduler ([`ScoreSdeVeScheduler`]):
A `ScoreSdeVeScheduler` to be used in combination with `unet` to denoise the encoded image.
"""
-
unet: UNet2DModel
scheduler: ScoreSdeVeScheduler
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index 19bd1f16152c..c467d5ebe829 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -146,22 +146,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
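The reverted `callback_steps` check above boils down to a small predicate that now also rejects `None`. A standalone sketch of that behaviour (the function name is illustrative):

```py
def validate_callback_steps(callback_steps):
    # Rejects None as well as non-positive or non-integer values.
    if (callback_steps is None) or (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

validate_callback_steps(1)     # passes
validate_callback_steps(None)  # raises ValueError
```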
diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py
index 2145bc25c40a..ac5c06042e59 100644
--- a/src/diffusers/pipelines/shap_e/renderer.py
+++ b/src/diffusers/pipelines/shap_e/renderer.py
@@ -911,7 +911,7 @@ def decode_to_image(
n_coarse_samples=64,
n_fine_samples=128,
):
- # project the parameters from the generated latents
+ # project the parameters from the generated latents
projected_params = self.params_proj(latents)
# update the mlp layers of the renderer
@@ -955,7 +955,7 @@ def decode_to_mesh(
query_batch_size: int = 4096,
texture_channels: Tuple = ("R", "G", "B"),
):
- # 1. project the parameters from the generated latents
+ # 1. project the parameters from the generated latents
projected_params = self.params_proj(latents)
# 2. update the mlp layers of the renderer
diff --git a/src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py b/src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
index 93af3b1189d0..5ab503df49ca 100644
--- a/src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
+++ b/src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
@@ -54,7 +54,6 @@ class SpectrogramDiffusionPipeline(DiffusionPipeline):
A scheduler to be used in combination with `decoder` to denoise the encoded audio latents.
melgan ([`OnnxRuntimeModel`]):
"""
-
_optional_components = ["melgan"]
def __init__(
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 5706298a281a..fcdca9c9f08b 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -55,9 +55,7 @@
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from ...utils.dummy_torch_and_transformers_objects import (
- StableDiffusionImageVariationPipeline,
- )
+ from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
_dummy_objects.update({"StableDiffusionImageVariationPipeline": StableDiffusionImageVariationPipeline})
else:
@@ -92,9 +90,7 @@
):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from ...utils import (
- dummy_torch_and_transformers_and_k_diffusion_objects,
- )
+ from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
@@ -141,32 +137,18 @@
StableDiffusionPipelineOutput,
StableDiffusionSafetyChecker,
)
- from .pipeline_stable_diffusion_attend_and_excite import (
- StableDiffusionAttendAndExcitePipeline,
- )
+ from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
- from .pipeline_stable_diffusion_gligen_text_image import (
- StableDiffusionGLIGENTextImagePipeline,
- )
+ from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
- from .pipeline_stable_diffusion_inpaint_legacy import (
- StableDiffusionInpaintPipelineLegacy,
- )
- from .pipeline_stable_diffusion_instruct_pix2pix import (
- StableDiffusionInstructPix2PixPipeline,
- )
- from .pipeline_stable_diffusion_latent_upscale import (
- StableDiffusionLatentUpscalePipeline,
- )
+ from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
+ from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
+ from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
- from .pipeline_stable_diffusion_model_editing import (
- StableDiffusionModelEditingPipeline,
- )
+ from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
- from .pipeline_stable_diffusion_paradigms import (
- StableDiffusionParadigmsPipeline,
- )
+ from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .pipeline_stable_unclip import StableUnCLIPPipeline
@@ -178,13 +160,9 @@
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0")):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from ...utils.dummy_torch_and_transformers_objects import (
- StableDiffusionImageVariationPipeline,
- )
+ from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
else:
- from .pipeline_stable_diffusion_image_variation import (
- StableDiffusionImageVariationPipeline,
- )
+ from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
try:
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.26.0")):
@@ -196,13 +174,9 @@
StableDiffusionPix2PixZeroPipeline,
)
else:
- from .pipeline_stable_diffusion_depth2img import (
- StableDiffusionDepth2ImgPipeline,
- )
+ from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
- from .pipeline_stable_diffusion_pix2pix_zero import (
- StableDiffusionPix2PixZeroPipeline,
- )
+ from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline
try:
if not (
@@ -215,9 +189,7 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
- from .pipeline_stable_diffusion_k_diffusion import (
- StableDiffusionKDiffusionPipeline,
- )
+ from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
try:
if not (is_transformers_available() and is_onnx_available()):
@@ -225,22 +197,11 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_onnx_objects import *
else:
- from .pipeline_onnx_stable_diffusion import (
- OnnxStableDiffusionPipeline,
- StableDiffusionOnnxPipeline,
- )
- from .pipeline_onnx_stable_diffusion_img2img import (
- OnnxStableDiffusionImg2ImgPipeline,
- )
- from .pipeline_onnx_stable_diffusion_inpaint import (
- OnnxStableDiffusionInpaintPipeline,
- )
- from .pipeline_onnx_stable_diffusion_inpaint_legacy import (
- OnnxStableDiffusionInpaintPipelineLegacy,
- )
- from .pipeline_onnx_stable_diffusion_upscale import (
- OnnxStableDiffusionUpscalePipeline,
- )
+ from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
+ from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
+ from .pipeline_onnx_stable_diffusion_inpaint import OnnxStableDiffusionInpaintPipeline
+ from .pipeline_onnx_stable_diffusion_inpaint_legacy import OnnxStableDiffusionInpaintPipelineLegacy
+ from .pipeline_onnx_stable_diffusion_upscale import OnnxStableDiffusionUpscalePipeline
try:
if not (is_transformers_available() and is_flax_available()):
@@ -249,12 +210,8 @@
from ...utils.dummy_flax_objects import *
else:
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
- from .pipeline_flax_stable_diffusion_img2img import (
- FlaxStableDiffusionImg2ImgPipeline,
- )
- from .pipeline_flax_stable_diffusion_inpaint import (
- FlaxStableDiffusionInpaintPipeline,
- )
+ from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
+ from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
from .pipeline_output import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
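The import reshuffling above all sits inside the same optional-dependency guard. A simplified sketch of that pattern, assuming only the standard library plus diffusers; the locally defined placeholder class stands in for diffusers' generated dummy objects:

```py
import importlib.util

class OptionalDependencyNotAvailable(Exception):
    """Raised when an optional backend is not installed."""

def is_transformers_available() -> bool:
    return importlib.util.find_spec("transformers") is not None

try:
    if not is_transformers_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Placeholder that only errors when someone actually tries to use it.
    class StableDiffusionImageVariationPipeline:
        def __init__(self, *args, **kwargs):
            raise ImportError(
                "StableDiffusionImageVariationPipeline requires `transformers` to be installed."
            )
else:
    from diffusers.pipelines.stable_diffusion import StableDiffusionImageVariationPipeline
```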
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 35466f008f54..e97f66bbcb24 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -324,7 +324,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
- if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
+ if "num_classes" in unet_params and type(unet_params.num_classes) == int:
config["num_class_embeds"] = unet_params.num_classes
if controlnet:
@@ -787,12 +787,7 @@ def _copy_layers(hf_layers, pt_layers):
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
if text_encoder is None:
config_name = "openai/clip-vit-large-patch14"
- try:
- config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
- )
+ config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
@@ -927,12 +922,7 @@ def convert_open_clip_checkpoint(
# text_model = CLIPTextModelWithProjection.from_pretrained(
# "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
# )
- try:
- config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: '{config_name}'."
- )
+ config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs, local_files_only=local_files_only)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
@@ -1145,7 +1135,6 @@ def download_from_original_stable_diffusion_ckpt(
stable_unclip_prior: Optional[str] = None,
clip_stats_path: Optional[str] = None,
controlnet: Optional[bool] = None,
- adapter: Optional[bool] = None,
load_safety_checker: bool = True,
pipeline_class: DiffusionPipeline = None,
local_files_only=False,
@@ -1232,11 +1221,13 @@ def download_from_original_stable_diffusion_ckpt(
StableDiffusionPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionXLImg2ImgPipeline,
- StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
+ if pipeline_class is None:
+ pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
+
if prediction_type == "v-prediction":
prediction_type = "v_prediction"
@@ -1331,13 +1322,6 @@ def download_from_original_stable_diffusion_ckpt(
if image_size is None:
image_size = 1024
- if pipeline_class is None:
- # Check if we have a SDXL or SD model and initialize default pipeline
- if model_type not in ["SDXL", "SDXL-Refiner"]:
- pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
- else:
- pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
-
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
num_in_channels = 9
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
@@ -1480,19 +1464,11 @@ def download_from_original_stable_diffusion_ckpt(
config_name = "stabilityai/stable-diffusion-2"
config_kwargs = {"subfolder": "text_encoder"}
- text_model = convert_open_clip_checkpoint(
- checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
+ text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
+ tokenizer = CLIPTokenizer.from_pretrained(
+ "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
)
- try:
- tokenizer = CLIPTokenizer.from_pretrained(
- "stabilityai/stable-diffusion-2", subfolder="tokenizer", local_files_only=local_files_only
- )
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'stabilityai/stable-diffusion-2'."
- )
-
if stable_unclip is None:
if controlnet:
pipe = pipeline_class(
@@ -1570,14 +1546,9 @@ def download_from_original_stable_diffusion_ckpt(
karlo_model, subfolder="prior", local_files_only=local_files_only
)
- try:
- prior_tokenizer = CLIPTokenizer.from_pretrained(
- "openai/clip-vit-large-patch14", local_files_only=local_files_only
- )
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
- )
+ prior_tokenizer = CLIPTokenizer.from_pretrained(
+ "openai/clip-vit-large-patch14", local_files_only=local_files_only
+ )
prior_text_model = CLIPTextModelWithProjection.from_pretrained(
"openai/clip-vit-large-patch14", local_files_only=local_files_only
)
@@ -1610,22 +1581,10 @@ def download_from_original_stable_diffusion_ckpt(
raise NotImplementedError(f"unknown `stable_unclip` type: {stable_unclip}")
elif model_type == "PaintByExample":
vision_model = convert_paint_by_example_checkpoint(checkpoint)
- try:
- tokenizer = CLIPTokenizer.from_pretrained(
- "openai/clip-vit-large-patch14", local_files_only=local_files_only
- )
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
- )
- try:
- feature_extractor = AutoFeatureExtractor.from_pretrained(
- "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
- )
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'."
- )
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
+ "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
+ )
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
@@ -1638,16 +1597,11 @@ def download_from_original_stable_diffusion_ckpt(
text_model = convert_ldm_clip_checkpoint(
checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
)
- try:
- tokenizer = (
- CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
- if tokenizer is None
- else tokenizer
- )
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
- )
+ tokenizer = (
+ CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
+ if tokenizer is None
+ else tokenizer
+ )
if load_safety_checker:
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
@@ -1683,33 +1637,18 @@ def download_from_original_stable_diffusion_ckpt(
)
elif model_type in ["SDXL", "SDXL-Refiner"]:
if model_type == "SDXL":
- try:
- tokenizer = CLIPTokenizer.from_pretrained(
- "openai/clip-vit-large-patch14", local_files_only=local_files_only
- )
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
- )
+ tokenizer = CLIPTokenizer.from_pretrained(
+ "openai/clip-vit-large-patch14", local_files_only=local_files_only
+ )
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
- try:
- tokenizer_2 = CLIPTokenizer.from_pretrained(
- "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
- )
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
- )
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
+ "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
+ )
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
text_encoder_2 = convert_open_clip_checkpoint(
- checkpoint,
- config_name,
- prefix="conditioner.embedders.1.model.",
- has_projection=True,
- local_files_only=local_files_only,
- **config_kwargs,
+ checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
)
if is_accelerate_available(): # SBM Now move model to cpu.
@@ -1729,18 +1668,6 @@ def download_from_original_stable_diffusion_ckpt(
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
- elif adapter:
- pipe = pipeline_class(
- vae=vae,
- text_encoder=text_encoder,
- tokenizer=tokenizer,
- text_encoder_2=text_encoder_2,
- tokenizer_2=tokenizer_2,
- unet=unet,
- adapter=adapter,
- scheduler=scheduler,
- force_zeros_for_empty_prompt=True,
- )
else:
pipe = pipeline_class(
vae=vae,
@@ -1755,23 +1682,14 @@ def download_from_original_stable_diffusion_ckpt(
else:
tokenizer = None
text_encoder = None
- try:
- tokenizer_2 = CLIPTokenizer.from_pretrained(
- "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
- )
- except Exception:
- raise ValueError(
- f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k' with `pad_token` set to '!'."
- )
+ tokenizer_2 = CLIPTokenizer.from_pretrained(
+ "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!", local_files_only=local_files_only
+ )
+
config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
config_kwargs = {"projection_dim": 1280}
text_encoder_2 = convert_open_clip_checkpoint(
- checkpoint,
- config_name,
- prefix="conditioner.embedders.0.model.",
- has_projection=True,
- local_files_only=local_files_only,
- **config_kwargs,
+ checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
)
if is_accelerate_available(): # SBM Now move model to cpu.
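The hunks above drop the try/except wrappers that translated offline cache misses into actionable errors. For context, a sketch of that wrapper pattern, assuming `transformers` is installed (the helper name is illustrative):

```py
from transformers import CLIPTokenizer

def load_clip_tokenizer(name: str = "openai/clip-vit-large-patch14", local_files_only: bool = False):
    # Without the wrapper, a missing local cache surfaces as a generic OSError;
    # with it, the user is told exactly which checkpoint to save locally first.
    try:
        return CLIPTokenizer.from_pretrained(name, local_files_only=local_files_only)
    except Exception as err:
        raise ValueError(
            f"With local_files_only set to {local_files_only}, you must first locally "
            f"save the tokenizer in the following path: '{name}'."
        ) from err

tokenizer = load_clip_tokenizer()
```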
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index e5c2c78720d5..d45e35d5cba0 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -61,20 +61,6 @@ def preprocess(image):
return image
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
# 1. get previous step value (=t-1)
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
@@ -162,7 +148,6 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
@@ -453,36 +438,25 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
- self,
- prompt,
- strength,
- callback_steps,
- negative_prompt=None,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -581,12 +555,11 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
if isinstance(generator, list):
init_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
- for i in range(image.shape[0])
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
- init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
@@ -934,7 +907,6 @@ def __call__(
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
- self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
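For context on the `prepare_latents` change above: the removed `retrieve_latents` helper (reproduced below from the deleted lines) normalises the different VAE encoder outputs before sampling. A short usage sketch follows; the checkpoint name is only illustrative:

```py
import torch
from diffusers import AutoencoderKL

def retrieve_latents(encoder_output, generator=None, sample_mode="sample"):
    # AutoencoderKL returns a latent distribution; some encoders return latents directly.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image in [-1, 1]
latents = retrieve_latents(vae.encode(image)) * vae.config.scaling_factor
```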
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 5598477c9238..bcf2a6217772 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -410,13 +410,13 @@ def __call__(
images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
- images = np.asarray(images).copy()
+ images = np.asarray(images)
# block images
if any(has_nsfw_concept):
for i, is_nsfw in enumerate(has_nsfw_concept):
if is_nsfw:
- images[i, 0] = np.asarray(images_uint8_casted[i])
+ images[i] = np.asarray(images_uint8_casted[i])
images = images.reshape(num_devices, batch_size, height, width, 3)
else:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
index aff99b43fa4f..055d9b02c15d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -33,7 +33,10 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64
def preprocess(image):
- deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
+ deprecation_message = (
+ "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use"
+ " VaeImageProcessor.preprocess(...) instead"
+ )
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
if isinstance(image, torch.Tensor):
return image
@@ -82,7 +85,6 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
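The reflowed `deprecation_message` above feeds diffusers' `deprecate` helper. A minimal usage sketch of that call as it appears in the hunk:

```py
from diffusers.utils import deprecate

deprecation_message = (
    "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use"
    " VaeImageProcessor.preprocess(...) instead"
)
# Warns with a FutureWarning while the installed diffusers version is below 1.0.0;
# standard_warn=False suppresses the boilerplate prefix so only the custom message is shown.
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
```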
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index b3dcc899c48f..88d300c10b55 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -80,7 +80,6 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
vae_encoder: OnnxRuntimeModel
vae_decoder: OnnxRuntimeModel
text_encoder: OnnxRuntimeModel
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
index 40abc477e7c0..fece365af49b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py
@@ -66,7 +66,6 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
_optional_components = ["safety_checker", "feature_extractor"]
_is_onnx = True
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index bf43c043490b..a9d28144e543 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -17,11 +17,11 @@
import torch
from packaging import version
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
-from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -70,53 +70,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
-class StableDiffusionPipeline(
- DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
-):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -128,7 +82,6 @@ class StableDiffusionPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -149,11 +102,9 @@ class StableDiffusionPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -164,7 +115,6 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
- image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -241,7 +191,6 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
- image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -485,23 +434,10 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -553,22 +489,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -639,62 +570,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def guidance_rescale(self):
- return self._guidance_rescale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -703,7 +578,6 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -712,15 +586,13 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -735,10 +607,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -763,12 +631,17 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference, with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -779,15 +652,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -798,23 +662,6 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
- )
-
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -822,21 +669,9 @@ def __call__(
# 1. Check inputs. Raise error if not correct
self.check_inputs(
- prompt,
- height,
- width,
- callback_steps,
- negative_prompt,
- prompt_embeds,
- negative_prompt_embeds,
- callback_on_step_end_tensor_inputs,
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
- self._guidance_scale = guidance_scale
- self._guidance_rescale = guidance_rescale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -846,37 +681,34 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
- lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
- )
+ lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
-
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
-
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -894,24 +726,12 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 6.1 Add image embeds for IP-Adapter
- added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
-
- # 6.2 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
@@ -919,34 +739,22 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -955,9 +763,7 @@ def __call__(
callback(step_idx, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
- 0
- ]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
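
The hunks above also restore the legacy per-step `callback`/`callback_steps` interface in place of `callback_on_step_end`. As an editorial illustration (not part of the diff), here is a minimal sketch of how that legacy interface is typically used; the checkpoint id is only an example:

```py
import torch
from diffusers import StableDiffusionPipeline

# Any Stable Diffusion checkpoint works the same way; this id is illustrative.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Legacy signature: callback(step: int, timestep: int, latents: torch.FloatTensor)
def on_step(step, timestep, latents):
    print(f"step {step} / timestep {int(timestep)} / latents {tuple(latents.shape)}")

# callback_steps=1 invokes the callback on every denoising step.
image = pipe(
    "a photograph of an astronaut riding a horse",
    callback=on_step,
    callback_steps=1,
).images[0]
```
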
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
index 5950139fd6e1..153efae876cd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -196,7 +196,6 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -470,7 +469,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -1028,7 +1027,6 @@ def __call__(
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
- self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index e431fee7bdb0..d73cf769e3ae 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -36,20 +36,6 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
@@ -99,9 +85,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "depth_mask"]
def __init__(
self,
@@ -359,7 +343,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -410,30 +394,19 @@ def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
- self,
- prompt,
- strength,
- callback_steps,
- negative_prompt=None,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -493,12 +466,11 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
elif isinstance(generator, list):
init_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
- for i in range(batch_size)
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
- init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
@@ -573,29 +545,6 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui
depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
return depth_map
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
def __call__(
self,
@@ -613,11 +562,10 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -665,21 +613,18 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+                A function that is called every `callback_steps` steps during inference, with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
```py
@@ -708,23 +653,6 @@ def __call__(
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
# 1. Check inputs
self.check_inputs(
prompt,
@@ -733,13 +661,8 @@ def __call__(
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
if image is None:
raise ValueError("`image` input cannot be undefined.")
@@ -752,26 +675,30 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare depth mask
@@ -779,7 +706,7 @@ def __call__(
image,
depth_map,
batch_size * num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
prompt_embeds.dtype,
device,
)
@@ -802,11 +729,10 @@ def __call__(
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1)
@@ -815,29 +741,18 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- depth_mask = callback_outputs.pop("depth_mask", depth_mask)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -851,7 +766,6 @@ def __call__(
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
- self.maybe_free_model_hooks()
if not return_dict:
return (image,)
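
The guidance changes in this file boil down to computing `do_classifier_free_guidance = guidance_scale > 1.0` inline and combining the two UNet predictions directly with the local `guidance_scale`. As a standalone sketch on dummy tensors (not code from the diff), the combine step looks like this:

```py
import torch

# Stand-in for a UNet output over a CFG batch: [unconditional, text-conditioned] along dim 0.
noise_pred = torch.randn(2, 4, 64, 64)
guidance_scale = 7.5

# Split the batched prediction back into its two halves...
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

# ...and push the result away from the unconditional branch, as in the pipelines above.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert noise_pred.shape == (1, 4, 64, 64)
```
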
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
index 3d48c811cdf1..451ef690a759 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py
@@ -273,7 +273,6 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -615,7 +614,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py
index b85f40a54579..ce7faaed2ab1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py
@@ -125,7 +125,6 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
_optional_components = ["safety_checker", "feature_extractor"]
model_cpu_offload_seq = "text_encoder->unet->vae"
_exclude_from_cpu_offload = ["safety_checker"]
@@ -412,7 +411,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -865,8 +864,9 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
- # Offload all models
- self.maybe_free_model_hooks()
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
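
The final hunk above swaps `maybe_free_model_hooks()` for the older `final_offload_hook` check; either path is a no-op unless CPU offload has been enabled on the pipeline. A minimal sketch of enabling it (the checkpoint id is illustrative; the mechanism is the same for the GLIGEN pipelines):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Keeps each sub-model on the CPU and moves it to the GPU only while it runs;
# this is what installs the offload hooks that __call__ releases at the end.
pipe.enable_model_cpu_offload()

image = pipe("a photograph of an astronaut riding a horse").images[0]
```
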
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py
index 405097248e2a..67f3fe0e9448 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py
@@ -177,7 +177,6 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -437,7 +436,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -484,22 +483,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -1037,8 +1031,9 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
- # Offload all models
- self.maybe_free_model_hooks()
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index be19b74ab438..c6797a0693cc 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -62,7 +62,6 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
# TODO: feature_extractor is required to encode images (if they are in PIL format),
# we should give a descriptive message if the pipeline doesn't have one.
_optional_components = ["safety_checker"]
@@ -440,8 +439,6 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
- self.maybe_free_model_hooks()
-
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index e3a1a0ed3660..2532c15696e4 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -19,11 +19,11 @@
import PIL.Image
import torch
from packaging import version
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -73,19 +73,6 @@
"""
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
def preprocess(image):
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
@@ -109,53 +96,8 @@ def preprocess(image):
return image
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
class StableDiffusionImg2ImgPipeline(
- DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image-to-image generation using Stable Diffusion.
@@ -168,7 +110,6 @@ class StableDiffusionImg2ImgPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -189,11 +130,9 @@ class StableDiffusionImg2ImgPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -204,7 +143,6 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
- image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -281,7 +219,6 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
- image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -498,24 +435,10 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -562,30 +485,19 @@ def prepare_extra_step_kwargs(self, generator, eta):
return extra_step_kwargs
def check_inputs(
- self,
- prompt,
- strength,
- callback_steps,
- negative_prompt=None,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -643,12 +555,11 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
elif isinstance(generator, list):
init_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
- for i in range(batch_size)
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
- init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
@@ -707,58 +618,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -767,7 +626,6 @@ def __call__(
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
- timesteps: List[int] = None,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -775,14 +633,12 @@ def __call__(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -805,10 +661,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -829,27 +681,23 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+                A function that is called every `callback_steps` steps during inference, with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
@@ -859,37 +707,8 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
# 1. Check inputs. Raise error if not correct
- self.check_inputs(
- prompt,
- strength,
- callback_steps,
- negative_prompt,
- prompt_embeds,
- negative_prompt_embeds,
- callback_on_step_end_tensor_inputs,
- )
-
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
+ self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -898,75 +717,55 @@ def __call__(
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
-
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
-
# 4. Preprocess image
image = self.image_processor.preprocess(image)
# 5. set timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
latents = self.prepare_latents(
- image,
- latent_timestep,
- batch_size,
- num_images_per_prompt,
- prompt_embeds.dtype,
- device,
- generator,
+ image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 7.1 Add image embeds for IP-Adapter
- added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
-
- # 7.2 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
@@ -974,30 +773,18 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -1006,9 +793,7 @@ def __call__(
callback(step_idx, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
- 0
- ]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
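
Several files in this diff replace the small `retrieve_latents` helper with a direct `latent_dist.sample(generator)` call; for a standard `AutoencoderKL` the two are equivalent in the default `"sample"` mode. A rough standalone sketch, assuming the `stabilityai/sd-vae-ft-mse` VAE purely as an example checkpoint:

```py
import torch
from diffusers import AutoencoderKL

# Any AutoencoderKL exposes the same `encode(...).latent_dist` API.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

image = torch.randn(1, 3, 512, 512)            # dummy image batch in [-1, 1]
generator = torch.Generator().manual_seed(0)

with torch.no_grad():
    posterior = vae.encode(image).latent_dist

# Default "sample" mode, i.e. what both code paths in the diff compute:
latents = posterior.sample(generator) * vae.config.scaling_factor

# The removed helper also offered an "argmax" mode, which uses the
# distribution mode (its mean) instead of a random sample:
latents_mode = posterior.mode() * vae.config.scaling_factor
```
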
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 3570eaa6fd3d..c6361c616653 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -19,11 +19,11 @@
import PIL.Image
import torch
from packaging import version
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -159,67 +159,8 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
return mask, masked_image
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
class StableDiffusionInpaintPipeline(
- DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion.
@@ -231,7 +172,6 @@ class StableDiffusionInpaintPipeline(
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
@@ -252,11 +192,9 @@ class StableDiffusionInpaintPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
- _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"]
def __init__(
self,
@@ -267,7 +205,6 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
- image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -349,7 +286,6 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
- image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -569,24 +505,10 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -630,7 +552,6 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -638,19 +559,14 @@ def check_inputs(
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -738,12 +654,12 @@ def prepare_latents(
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
image_latents = self.vae.config.scaling_factor * image_latents
@@ -832,58 +748,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
def __call__(
self,
@@ -895,7 +759,6 @@ def __call__(
width: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -904,14 +767,12 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -946,10 +807,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -974,27 +831,23 @@ def __call__(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+                A function that is called every `callback_steps` steps during inference, with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
```py
@@ -1033,23 +886,6 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -1064,13 +900,8 @@ def __call__(
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1080,6 +911,10 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
@@ -1089,26 +924,21 @@ def __call__(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
-
# 4. set timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
)
@@ -1171,7 +1001,7 @@ def __call__(
prompt_embeds.dtype,
device,
generator,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
)
# 8. Check that sizes of mask, masked image and latents match
@@ -1195,24 +1025,12 @@ def __call__(
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 9.1 Add image embeds for IP-Adapter
- added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
-
- # 9.2 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
# 10. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1225,22 +1043,20 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if num_channels_unet == 4:
init_latents_proper = image_latents
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
else:
init_mask = mask
@@ -1253,18 +1069,6 @@ def __call__(
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- mask = callback_outputs.pop("mask", mask)
- masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -1280,9 +1084,7 @@ def __call__(
init_image = self._encode_vae_image(init_image, generator=generator)
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
- image = self.vae.decode(
- latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
- )[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
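Note on the hunks above: they move the inpainting pipeline back from `callback_on_step_end` to the older `callback`/`callback_steps` pair. A minimal usage sketch of that restored interface follows — the checkpoint name, prompt, and the blank `init_image`/`mask_image` inputs are placeholders for illustration, not taken from the patch:

```py
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Example checkpoint; any Stable Diffusion inpainting checkpoint would do.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Placeholder inputs: a blank image and a mask that covers everything.
init_image = Image.new("RGB", (512, 512), "white")
mask_image = Image.new("L", (512, 512), 255)

intermediate_latents = []

def save_latents(step: int, timestep: int, latents: torch.FloatTensor):
    # Matches the restored signature: callback(step, timestep, latents).
    intermediate_latents.append(latents.detach().cpu())

image = pipe(
    prompt="a red fox sitting on a park bench",
    image=init_image,
    mask_image=mask_image,
    callback=save_latents,
    callback_steps=5,  # run the callback every 5 denoising steps
).images[0]
```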
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index 15e6f60569a3..513c660c30cf 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -115,7 +115,6 @@ class StableDiffusionInpaintPipelineLegacy(
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -428,7 +427,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -479,30 +478,19 @@ def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
- self,
- prompt,
- strength,
- callback_steps,
- negative_prompt=None,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index d922803858b0..e3b7e34232d1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
@@ -58,20 +58,6 @@ def preprocess(image):
return image
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
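The `retrieve_latents` helper deleted in the hunk above is what the newer `prepare_image_latents` used (in `argmax` mode); a later hunk in this file goes back to calling `vae.encode(...).latent_dist.mode()` directly. For reference, the helper's behavior restated as a standalone sketch, lifted from the deleted lines:

```py
from typing import Optional

import torch

def retrieve_latents(encoder_output, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
    # "sample" draws from the VAE posterior, "argmax" takes its mode, and an
    # object exposing `latents` directly is passed through unchanged.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
```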
@@ -103,11 +89,9 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "image_latents"]
def __init__(
self,
@@ -168,9 +152,8 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
):
r"""
The call function to the pipeline for generation.
@@ -218,15 +201,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
Examples:
@@ -264,34 +244,8 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
# 0. Check inputs
- self.check_inputs(
- prompt,
- callback_steps,
- negative_prompt,
- prompt_embeds,
- negative_prompt_embeds,
- callback_on_step_end_tensor_inputs,
- )
- self._guidance_scale = guidance_scale
- self._image_guidance_scale = image_guidance_scale
+ self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
if image is None:
raise ValueError("`image` input cannot be undefined.")
@@ -305,6 +259,10 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0
# check if scheduler is in sigmas space
scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
@@ -313,7 +271,7 @@ def __call__(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@@ -333,7 +291,8 @@ def __call__(
num_images_per_prompt,
prompt_embeds.dtype,
device,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
+ generator,
)
height, width = image_latents.shape[-2:]
@@ -369,13 +328,12 @@ def __call__(
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Expand the latents if we are doing classifier free guidance.
# The latents are expanded 3 times because for pix2pix the guidance\
# is applied for both the text and the input image.
- latent_model_input = torch.cat([latents] * 3) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
# concat latents, image_latents in the channel dimension
scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -396,12 +354,12 @@ def __call__(
noise_pred = latent_model_input - sigma * noise_pred
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
noise_pred = (
noise_pred_uncond
- + self.guidance_scale * (noise_pred_text - noise_pred_image)
- + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
+ + guidance_scale * (noise_pred_text - noise_pred_image)
+ + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
)
# Hack:
@@ -416,17 +374,6 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- image_latents = callback_outputs.pop("image_latents", image_latents)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -649,27 +596,16 @@ def decode_latents(self, latents):
return image
def check_inputs(
- self,
- prompt,
- callback_steps,
- negative_prompt=None,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
+ self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -729,7 +665,17 @@ def prepare_image_latents(
if image.shape[1] == 4:
image_latents = image
else:
- image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.mode()
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand image_latents for batch_size
@@ -782,22 +728,3 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def image_guidance_scale(self):
- return self._image_guidance_scale
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0
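The pix2pix guidance step touched in the loop above combines three noise predictions (text-conditioned, image-conditioned, unconditional). A minimal standalone sketch of that combination, with random tensors standing in for real UNet outputs:

```py
import torch

def instruct_pix2pix_guidance(
    noise_pred_text: torch.Tensor,
    noise_pred_image: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    guidance_scale: float = 7.5,
    image_guidance_scale: float = 1.5,
) -> torch.Tensor:
    # Mirrors the combination in the denoising loop: the text guidance pushes
    # away from the image-only prediction, and the image guidance pushes away
    # from the fully unconditional prediction.
    return (
        noise_pred_uncond
        + guidance_scale * (noise_pred_text - noise_pred_image)
        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
    )

# Toy shapes: (batch, latent channels, height / 8, width / 8).
t, im, un = (torch.randn(1, 4, 64, 64) for _ in range(3))
print(instruct_pix2pix_guidance(t, im, un).shape)  # torch.Size([1, 4, 64, 64])
```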
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 388e5a4b5ebd..e0bb9b6e0b14 100755
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -80,7 +80,6 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -342,7 +341,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -383,22 +382,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
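Many of the `encode_prompt` hunks in this patch drop the explicit `lora_scale` argument from `unscale_lora_layers`. The scale/unscale pair brackets the text-encoder forward pass, roughly as in the sketch below; `text_encoder` and `input_ids` are placeholders, and only the two `diffusers.utils` helpers are the real API:

```py
from diffusers.utils import scale_lora_layers, unscale_lora_layers

def encode_with_lora_scale(text_encoder, input_ids, lora_scale: float):
    # Temporarily scale every LoRA layer in the text encoder for this forward
    # pass, then restore it afterwards. The patched lines call
    # `unscale_lora_layers` without the explicit weight, as done here.
    scale_lora_layers(text_encoder, lora_scale)
    try:
        prompt_embeds = text_encoder(input_ids)[0]
    finally:
        unscale_lora_layers(text_encoder)
    return prompt_embeds
```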
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index cfbbb7aaab72..1e8c98c44750 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -79,7 +79,6 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, FromSingleFileMixi
scheduler ([`SchedulerMixin`]):
A [`EulerDiscreteScheduler`] to be used in combination with `unet` to denoise the encoded image latents.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
def __init__(
@@ -512,8 +511,6 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type)
- self.maybe_free_model_hooks()
-
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
index f410c08a3bbe..2e514a55108c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py
@@ -115,7 +115,6 @@ class StableDiffusionLDM3DPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -406,7 +405,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -453,22 +452,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
index c6364891e445..6c78d190d97f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py
@@ -66,7 +66,6 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
with_augs ([`list`]):
Textual augmentations to apply while editing the text-to-image model. Set to `[]` for no augmentations.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -375,7 +374,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -434,22 +433,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index ff6a66ab57c9..bac1f83fb336 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -85,7 +85,6 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -359,7 +358,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -431,22 +430,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -803,8 +797,6 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
- self.maybe_free_model_hooks()
-
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
index f0368b4ca305..161f656fee2e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py
@@ -96,7 +96,6 @@ class StableDiffusionParadigmsPipeline(
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -390,7 +389,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -437,22 +436,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index df9849ead723..6d4286a04686 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -310,7 +310,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the
pipeline publicly.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = [
"safety_checker",
@@ -580,7 +579,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index 68652e977c5d..6a78d4da4545 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -124,7 +124,6 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -382,7 +381,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -441,22 +440,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -742,8 +736,6 @@ def get_map_size(module, input, output):
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
- self.maybe_free_model_hooks()
-
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index ceb316331b38..f3d92119b8d2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -92,7 +92,6 @@ class StableDiffusionUpscalePipeline(
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["watermarker", "safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -373,7 +372,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -821,8 +820,9 @@ def __call__(
if output_type == "pil" and self.watermarker is not None:
image = self.watermarker.apply_watermark(image)
- # Offload all models
- self.maybe_free_model_hooks()
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
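The tail of the upscale pipeline goes back from `maybe_free_model_hooks()` to checking `final_offload_hook` directly; either way, it only does something when CPU offload is active. A usage sketch with an illustrative checkpoint and a synthetic low-resolution input:

```py
import torch
from PIL import Image
from diffusers import StableDiffusionUpscalePipeline

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
)
# Installs the accelerate offload hooks; `final_offload_hook` is the attribute
# the reverted code checks before moving the last model back to the CPU.
pipe.enable_model_cpu_offload()

low_res = Image.new("RGB", (128, 128), "gray")  # stand-in for a real low-res image
upscaled = pipe(prompt="a white cat", image=low_res).images[0]
```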
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index eb4542888c1f..3bce80fdb5b1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -206,15 +206,17 @@ def _encode_prior_prompt(
prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))
prompt_embeds = prior_text_encoder_output.text_embeds
- text_enc_hid_states = prior_text_encoder_output.last_hidden_state
+ prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
- prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
+ prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
- text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
+ prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
@@ -233,7 +235,9 @@ def _encode_prior_prompt(
)
negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
- uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
+ uncond_prior_text_encoder_hidden_states = (
+ negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
+ )
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -241,9 +245,11 @@ def _encode_prior_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
- seq_len = uncond_text_enc_hid_states.shape[1]
- uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
- uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
+ seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
+ uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
+ 1, num_images_per_prompt, 1
+ )
+ uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -254,11 +260,13 @@ def _encode_prior_prompt(
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
+ prior_text_encoder_hidden_states = torch.cat(
+ [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
+ )
text_mask = torch.cat([uncond_text_mask, text_mask])
- return prompt_embeds, text_enc_hid_states, text_mask
+ return prompt_embeds, prior_text_encoder_hidden_states, text_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
@@ -471,7 +479,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -934,8 +942,9 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type)
- # Offload all models
- self.maybe_free_model_hooks()
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
if not return_dict:
return (image,)
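The `_encode_prior_prompt` hunk above is mostly a rename (`text_enc_hid_states` back to `prior_text_encoder_hidden_states`), but the batching it wraps is easy to lose in the noise: duplicate the hidden states per requested image, then concatenate the unconditional and conditional halves so one forward pass serves both classifier-free guidance branches. A plain-torch toy version with made-up shapes:

```py
import torch

num_images_per_prompt = 2
batch_size, seq_len, dim = 1, 77, 1024  # toy shapes

hidden = torch.randn(batch_size, seq_len, dim)         # conditional hidden states
uncond_hidden = torch.randn(batch_size, seq_len, dim)  # unconditional hidden states

# Conditional side: repeat_interleave along the batch dimension.
hidden = hidden.repeat_interleave(num_images_per_prompt, dim=0)

# Unconditional side: repeat along the sequence axis, then reshape back
# (the "mps friendly" pattern used in the pipeline).
uncond_hidden = uncond_hidden.repeat(1, num_images_per_prompt, 1)
uncond_hidden = uncond_hidden.view(batch_size * num_images_per_prompt, seq_len, -1)

# Stack [unconditional, conditional] into a single classifier-free-guidance batch.
cfg_batch = torch.cat([uncond_hidden, hidden])
print(cfg_batch.shape)  # torch.Size([4, 77, 1024])
```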
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 73638fdd15da..a17a674b7066 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -433,7 +433,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -839,8 +839,9 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type)
- # Offload all models
- self.maybe_free_model_hooks()
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index eb24cbfd947b..12f4551d9de3 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -364,22 +364,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 40c981a46d48..55bf929a2ee2 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -16,18 +16,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextModel,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionModelWithProjection,
-)
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
-from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...image_processor import VaeImageProcessor
from ...loaders import (
FromSingleFileMixin,
- IPAdapterMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
@@ -42,7 +35,6 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
- deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
@@ -100,57 +92,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
class StableDiffusionXLPipeline(
- DiffusionPipeline,
- FromSingleFileMixin,
- StableDiffusionXLLoraLoaderMixin,
- TextualInversionLoaderMixin,
- IPAdapterMixin,
+ DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL.
@@ -196,25 +139,7 @@ class StableDiffusionXLPipeline(
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
watermarker will be used.
"""
-
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
- _optional_components = [
- "tokenizer",
- "tokenizer_2",
- "text_encoder",
- "text_encoder_2",
- "image_encoder",
- "feature_extractor",
- ]
- _callback_tensor_inputs = [
- "latents",
- "prompt_embeds",
- "negative_prompt_embeds",
- "add_text_embeds",
- "add_time_ids",
- "negative_pooled_prompt_embeds",
- "negative_add_time_ids",
- ]
def __init__(
self,
@@ -225,8 +150,6 @@ def __init__(
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
- image_encoder: CLIPVisionModelWithProjection = None,
- feature_extractor: CLIPImageProcessor = None,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
):
@@ -240,13 +163,10 @@ def __init__(
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
- image_encoder=image_encoder,
- feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-
self.default_sample_size = self.unet.config.sample_size
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -355,17 +275,12 @@ def encode_prompt(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
- if self.text_encoder is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
- else:
- scale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
- else:
- scale_lora_layers(self.text_encoder_2, lora_scale)
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -481,11 +396,7 @@ def encode_prompt(
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
- if self.text_encoder_2 is not None:
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -494,12 +405,7 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
-
- if self.text_encoder_2 is not None:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -511,32 +417,13 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)
- if self.text_encoder is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder_2, lora_scale)
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
+ unscale_lora_layers(self.text_encoder_2)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -568,24 +455,18 @@ def check_inputs(
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -652,13 +533,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents
- def _get_add_time_ids(
- self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
- ):
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
passed_add_embed_dim = (
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
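The `_get_add_time_ids` change above only affects where the text-encoder projection dim is read from; the micro-conditioning vector itself stays the concatenation of the three size tuples. A toy illustration — the 256/1280 values are typical SDXL config numbers assumed for the example, not taken from the patch:

```py
original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

add_time_ids = list(original_size + crops_coords_top_left + target_size)
print(add_time_ids)  # [1024, 1024, 0, 0, 1024, 1024]

addition_time_embed_dim = 256        # unet.config.addition_time_embed_dim (assumed)
text_encoder_projection_dim = 1280   # text_encoder_2.config.projection_dim (assumed)

# The pipeline checks this sum against unet.add_embedding.linear_1.in_features.
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
print(passed_add_embed_dim)  # 2816
```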
@@ -718,66 +597,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def guidance_rescale(self):
- return self._guidance_rescale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def denoising_end(self):
- return self._denoising_end
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -787,7 +606,6 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -800,9 +618,10 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
@@ -812,9 +631,6 @@ def __call__(
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -839,10 +655,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -889,13 +701,18 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -934,15 +751,6 @@ def __call__(
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
- callback_on_step_end (`Callable`, *optional*):
- A function that is called at the end of each denoising step during inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -951,23 +759,6 @@ def __call__(
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
# 0. Default height and width to unet
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
@@ -988,15 +779,8 @@ def __call__(
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._guidance_rescale = guidance_rescale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
- self._denoising_end = denoising_end
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1007,10 +791,13 @@ def __call__(
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
# 3. Encode input prompt
- lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
- )
+ lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
(
prompt_embeds,
@@ -1022,7 +809,7 @@ def __call__(
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -1030,11 +817,13 @@ def __call__(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
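The hunk above trades the `retrieve_timesteps` helper (its full definition appears in removed blocks later in this diff) for a plain `scheduler.set_timesteps` call, which drops the custom `timesteps` path. A minimal sketch of what the helper did, assuming a scheduler whose `set_timesteps` accepts a `timesteps` keyword:

```py
import inspect

def retrieve_timesteps_sketch(scheduler, num_inference_steps=None, device=None, timesteps=None):
    # Mirrors the removed helper: prefer an explicit timestep schedule when the
    # scheduler supports one, otherwise fall back to num_inference_steps.
    if timesteps is not None:
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(f"{scheduler.__class__} does not support custom timestep schedules.")
        scheduler.set_timesteps(timesteps=timesteps, device=device)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device)
    return scheduler.timesteps, len(scheduler.timesteps)
```

With only `set_timesteps(num_inference_steps, ...)` left in the pipeline, a caller who wants a hand-picked schedule would have to configure the scheduler before invoking the pipeline.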
@@ -1054,17 +843,8 @@ def __call__(
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
- if self.text_encoder_2 is None:
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
- else:
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
add_time_ids = self._get_add_time_ids(
- original_size,
- crops_coords_top_left,
- target_size,
- dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
@@ -1072,12 +852,11 @@ def __call__(
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -1086,89 +865,50 @@ def __call__(
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
- image_embeds = image_embeds.to(device)
-
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 8.1 Apply denoising_end
- if (
- self.denoising_end is not None
- and isinstance(self.denoising_end, float)
- and self.denoising_end > 0
- and self.denoising_end < 1
- ):
+ if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
- # 9. Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
- if ip_adapter_image is not None:
- added_cond_kwargs["image_embeds"] = image_embeds
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
- negative_pooled_prompt_embeds = callback_outputs.pop(
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
- )
- add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
- negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
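Taken together, the hunks above move `StableDiffusionXLPipeline.__call__` from the `callback_on_step_end` hook back to the positional `callback`/`callback_steps` arguments. The sketch below contrasts the two styles as they appear in this diff; the checkpoint id, prompt, and step counts are illustrative placeholders:

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Style restored by this diff: a read-only progress hook called every `callback_steps` steps.
def legacy_callback(step: int, timestep: int, latents: torch.FloatTensor):
    print(f"step={step} timestep={int(timestep)} latents={tuple(latents.shape)}")

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=30,
    callback=legacy_callback,
    callback_steps=5,
).images[0]

# Style removed by this diff: a hook that may also return modified tensors,
# limited to the names listed in callback_on_step_end_tensor_inputs.
def on_step_end(pipeline, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    return {"latents": latents}

# image = pipe(
#     "a photo of an astronaut riding a horse",
#     callback_on_step_end=on_step_end,
#     callback_on_step_end_tensor_inputs=["latents"],
# ).images[0]
```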
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 436d816e5eb3..b436f404d5ea 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -17,21 +17,10 @@
import PIL.Image
import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextModel,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionModelWithProjection,
-)
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import (
- FromSingleFileMixin,
- IPAdapterMixin,
- StableDiffusionXLLoraLoaderMixin,
- TextualInversionLoaderMixin,
-)
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
@@ -43,7 +32,6 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
- deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
@@ -104,71 +92,8 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
class StableDiffusionXLImg2ImgPipeline(
- DiffusionPipeline,
- TextualInversionLoaderMixin,
- FromSingleFileMixin,
- StableDiffusionXLLoraLoaderMixin,
- IPAdapterMixin,
+ DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
):
r"""
Pipeline for image-to-image generation using Stable Diffusion XL.
@@ -217,25 +142,9 @@ class StableDiffusionXLImg2ImgPipeline(
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
watermarker will be used.
"""
-
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
- _optional_components = [
- "tokenizer",
- "tokenizer_2",
- "text_encoder",
- "text_encoder_2",
- "image_encoder",
- "feature_extractor",
- ]
- _callback_tensor_inputs = [
- "latents",
- "prompt_embeds",
- "negative_prompt_embeds",
- "add_text_embeds",
- "add_time_ids",
- "negative_pooled_prompt_embeds",
- "add_neg_time_ids",
- ]
+
+ _optional_components = ["tokenizer", "text_encoder"]
def __init__(
self,
@@ -246,8 +155,6 @@ def __init__(
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
- image_encoder: CLIPVisionModelWithProjection = None,
- feature_extractor: CLIPImageProcessor = None,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
@@ -261,8 +168,6 @@ def __init__(
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
- image_encoder=image_encoder,
- feature_extractor=feature_extractor,
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
@@ -377,17 +282,12 @@ def encode_prompt(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
- if self.text_encoder is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
- else:
- scale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
- else:
- scale_lora_layers(self.text_encoder_2, lora_scale)
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
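The LoRA branch above is driven by a `lora_scale` that callers route in through `cross_attention_kwargs={"scale": ...}` (see the `cross_attention_kwargs.get("scale", None)` lines elsewhere in this diff); both text encoders are then rescaled via `adjust_lora_scale_text_encoder` or `scale_lora_layers`. A hedged usage sketch; the checkpoint, LoRA path, and image URL are placeholders:

```py
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/some-sdxl-lora")  # placeholder path

init_image = load_image("https://example.com/input.png")  # placeholder URL

# The 0.7 below reaches encode_prompt() as lora_scale and scales the LoRA layers
# of text_encoder and text_encoder_2 before the prompt is embedded.
image = pipe(
    prompt="a watercolor landscape",
    image=init_image,
    cross_attention_kwargs={"scale": 0.7},
).images[0]
```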
@@ -503,11 +403,7 @@ def encode_prompt(
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
- if self.text_encoder_2 is not None:
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -516,12 +412,7 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
-
- if self.text_encoder_2 is not None:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -533,15 +424,10 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)
- if self.text_encoder is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder_2, lora_scale)
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
+ unscale_lora_layers(self.text_encoder_2)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@@ -574,7 +460,6 @@ def check_inputs(
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -585,19 +470,14 @@ def check_inputs(
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
f" {type(num_inference_steps)}."
)
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -655,20 +535,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)
-
- num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
- if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
- # if the scheduler is a 2nd order scheduler we might have to do +1
- # because `num_inference_steps` might be even given that every timestep
- # (except the highest one) is duplicated. If `num_inference_steps` is even it would
- # mean that we cut the timesteps in the middle of the denoising step
- # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
- # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
- num_inference_steps = num_inference_steps + 1
-
- # because t_n+1 >= t_n, we slice the timesteps starting from the end
- timesteps = timesteps[-num_inference_steps:]
- return timesteps, num_inference_steps
+ timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
+ return torch.tensor(timesteps), len(timesteps)
return timesteps, num_inference_steps - t_start
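For reference, the arithmetic in the branch this hunk removes: the cutoff converts `denoising_start` into a training-timestep threshold, counts the remaining timesteps, and, for second-order schedulers, adds one when the count is even so the slice never splits a solver step between its two derivative evaluations. A toy numeric sketch with made-up values:

```py
import torch

# Illustrative values only; a real scheduler provides timesteps and num_train_timesteps.
num_train_timesteps = 1000
denoising_start = 0.6
scheduler_order = 2

# Toy 2nd-order-style schedule with duplicated interior timesteps.
timesteps = torch.tensor([999, 799, 799, 599, 599, 399, 399, 199, 199])

cutoff = int(round(num_train_timesteps - denoising_start * num_train_timesteps))  # 400
num_inference_steps = (timesteps < cutoff).sum().item()                           # 4
if scheduler_order == 2 and num_inference_steps % 2 == 0:
    # An even count would cut between the two halves of a 2nd-order step, so add one.
    num_inference_steps += 1                                                       # 5
timesteps = timesteps[-num_inference_steps:]  # tensor([599, 399, 399, 199, 199])
```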
@@ -706,12 +574,11 @@ def prepare_latents(
elif isinstance(generator, list):
init_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
- for i in range(batch_size)
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
- init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
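The `retrieve_latents` helper being dropped in the hunk above (its body sits in the removed block near the top of this file's diff) is mostly a dispatcher around the VAE encoder output. A self-contained restatement of that dispatch, kept here only as a reading aid:

```py
def retrieve_latents_sketch(encoder_output, generator=None, sample_mode="sample"):
    # Same three cases as the removed helper.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)  # stochastic sample (default)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()              # deterministic mode
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents                          # already-computed latents
    raise AttributeError("Could not access latents of provided encoder_output")

# After the revert, prepare_latents calls vae.encode(...).latent_dist.sample(generator)
# directly, i.e. only the first branch; the "argmax" and plain-latents cases go away here.
```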
@@ -740,20 +607,6 @@ def prepare_latents(
return latents
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
def _get_add_time_ids(
self,
original_size,
@@ -765,7 +618,6 @@ def _get_add_time_ids(
negative_crops_coords_top_left,
negative_target_size,
dtype,
- text_encoder_projection_dim=None,
):
if self.config.requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
@@ -777,7 +629,7 @@ def _get_add_time_ids(
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
passed_add_embed_dim = (
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
@@ -853,70 +705,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def guidance_rescale(self):
- return self._guidance_rescale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def denoising_end(self):
- return self._denoising_end
-
- @property
- def denoising_start(self):
- return self._denoising_start
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -926,7 +714,6 @@ def __call__(
image: PipelineImageInput = None,
strength: float = 0.3,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
@@ -940,9 +727,10 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Tuple[int, int] = None,
@@ -954,9 +742,6 @@ def __call__(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -980,10 +765,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
denoising_start (`float`, *optional*):
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
@@ -1038,13 +819,18 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -1094,15 +880,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that is called at the end of each denoising step during inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -1111,23 +888,6 @@ def __call__(
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
@@ -1139,16 +899,8 @@ def __call__(
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._guidance_rescale = guidance_rescale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
- self._denoising_end = denoising_end
- self._denoising_start = denoising_start
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1159,9 +911,14 @@ def __call__(
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
(
prompt_embeds,
@@ -1173,7 +930,7 @@ def __call__(
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -1181,7 +938,7 @@ def __call__(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# 4. Preprocess image
@@ -1189,18 +946,15 @@ def __call__(
# 5. Prepare timesteps
def denoising_value_valid(dnv):
- return isinstance(self.denoising_end, float) and 0 < dnv < 1
+ return isinstance(denoising_end, float) and 0 < dnv < 1
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(
- num_inference_steps,
- strength,
- device,
- denoising_start=self.denoising_start if denoising_value_valid else None,
+ num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid else None
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- add_noise = True if self.denoising_start is None else False
+ add_noise = True if denoising_start is None else False
# 6. Prepare latent variables
latents = self.prepare_latents(
image,
@@ -1229,11 +983,6 @@ def denoising_value_valid(dnv):
negative_target_size = target_size
add_text_embeds = pooled_prompt_embeds
- if self.text_encoder_2 is None:
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
- else:
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
@@ -1244,11 +993,10 @@ def denoising_value_valid(dnv):
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
@@ -1258,95 +1006,61 @@ def denoising_value_valid(dnv):
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device)
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
- image_embeds = image_embeds.to(device)
-
# 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 9.1 Apply denoising_end
if (
- self.denoising_end is not None
- and self.denoising_start is not None
- and denoising_value_valid(self.denoising_end)
- and denoising_value_valid(self.denoising_start)
- and self.denoising_start >= self.denoising_end
+ denoising_end is not None
+ and denoising_start is not None
+ and denoising_value_valid(denoising_end)
+ and denoising_value_valid(denoising_start)
+ and denoising_start >= denoising_end
):
raise ValueError(
- f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
- + f" {self.denoising_end} when using type float."
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {denoising_end} when using type float."
)
- elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
- # 9.2 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
- if ip_adapter_image is not None:
- added_cond_kwargs["image_embeds"] = image_embeds
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
- negative_pooled_prompt_embeds = callback_outputs.pop(
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
- )
- add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
- add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
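`denoising_start` and `denoising_end`, which the hunks above rewrite from `self.*` properties back into plain locals, exist so the SDXL base pipeline can stop partway through the schedule and hand raw latents to this img2img pipeline acting as a refiner. A hedged end-to-end sketch; the checkpoints, prompt, and the 0.8 split are illustrative choices, not values taken from this diff:

```py
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
).to("cuda")

prompt = "a majestic lion jumping from a big stone at night"

# Base runs the first 80% of the schedule and returns latents instead of decoded pixels.
latents = base(
    prompt=prompt, num_inference_steps=40, denoising_end=0.8, output_type="latent"
).images

# The refiner resumes at the same fraction; denoising_start skips the part already done.
image = refiner(
    prompt=prompt, image=latents, num_inference_steps=40, denoising_start=0.8
).images[0]
```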
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index f54b680dfd7c..c04d2c0518c1 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -18,21 +18,10 @@
import numpy as np
import PIL.Image
import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextModel,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionModelWithProjection,
-)
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import (
- FromSingleFileMixin,
- IPAdapterMixin,
- StableDiffusionXLLoraLoaderMixin,
- TextualInversionLoaderMixin,
-)
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
@@ -249,71 +238,8 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
return mask, masked_image
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
class StableDiffusionXLInpaintPipeline(
- DiffusionPipeline,
- TextualInversionLoaderMixin,
- StableDiffusionXLLoraLoaderMixin,
- FromSingleFileMixin,
- IPAdapterMixin,
+ DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion XL.
@@ -362,28 +288,9 @@ class StableDiffusionXLInpaintPipeline(
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
watermarker will be used.
"""
-
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
- _optional_components = [
- "tokenizer",
- "tokenizer_2",
- "text_encoder",
- "text_encoder_2",
- "image_encoder",
- "feature_extractor",
- ]
- _callback_tensor_inputs = [
- "latents",
- "prompt_embeds",
- "negative_prompt_embeds",
- "add_text_embeds",
- "add_time_ids",
- "negative_pooled_prompt_embeds",
- "add_neg_time_ids",
- "mask",
- "masked_image_latents",
- ]
+ _optional_components = ["tokenizer", "text_encoder"]
def __init__(
self,
@@ -394,8 +301,6 @@ def __init__(
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
- image_encoder: CLIPVisionModelWithProjection = None,
- feature_extractor: CLIPImageProcessor = None,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
@@ -409,8 +314,6 @@ def __init__(
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
- image_encoder=image_encoder,
- feature_extractor=feature_extractor,
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
@@ -461,20 +364,6 @@ def disable_vae_tiling(self):
"""
self.vae.disable_tiling()
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
- dtype = next(self.image_encoder.parameters()).dtype
-
- if not isinstance(image, torch.Tensor):
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
- image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
-
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
self,
@@ -542,17 +431,12 @@ def encode_prompt(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
- if self.text_encoder is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
- else:
- scale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
- else:
- scale_lora_layers(self.text_encoder_2, lora_scale)
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -668,11 +552,7 @@ def encode_prompt(
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
- if self.text_encoder_2 is not None:
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -681,12 +561,7 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
-
- if self.text_encoder_2 is not None:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -698,15 +573,10 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)
- if self.text_encoder is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder_2, lora_scale)
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
+ unscale_lora_layers(self.text_encoder_2)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@@ -740,7 +610,6 @@ def check_inputs(
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -748,19 +617,14 @@ def check_inputs(
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -831,11 +695,10 @@ def prepare_latents(
if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype)
- image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
elif return_image_latents or (latents is None and not is_strength_max):
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
- image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
if latents is None and add_noise:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -868,12 +731,12 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
- retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
- image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
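`_encode_vae_image` above accepts either one `torch.Generator` or a list with one generator per image, so every latent sample in a batch stays individually reproducible. A short sketch of the list form; the seeds, prompts, and input tensors are placeholders:

```py
import torch

# One generator per batch element; seeds are arbitrary. `pipe`, `images`, and
# `mask_image` stand for an already-loaded SDXL inpaint pipeline and its batched inputs.
seeds = [0, 1, 2, 3]
generators = [torch.Generator(device="cuda").manual_seed(s) for s in seeds]

# result = pipe(
#     prompt=["a red sofa in a bright living room"] * len(seeds),
#     image=images,
#     mask_image=mask_image,
#     generator=generators,  # routed into prepare_latents -> _encode_vae_image above
# ).images
```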
@@ -956,20 +819,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)
-
- num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
- if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
- # if the scheduler is a 2nd order scheduler we might have to do +1
- # because `num_inference_steps` might be even given that every timestep
- # (except the highest one) is duplicated. If `num_inference_steps` is even it would
- # mean that we cut the timesteps in the middle of the denoising step
- # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
- # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
- num_inference_steps = num_inference_steps + 1
-
- # because t_n+1 >= t_n, we slice the timesteps starting from the end
- timesteps = timesteps[-num_inference_steps:]
- return timesteps, num_inference_steps
+ timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
+ return torch.tensor(timesteps), len(timesteps)
return timesteps, num_inference_steps - t_start
@@ -985,7 +836,6 @@ def _get_add_time_ids(
negative_crops_coords_top_left,
negative_target_size,
dtype,
- text_encoder_projection_dim=None,
):
if self.config.requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
@@ -997,7 +847,7 @@ def _get_add_time_ids(
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
passed_add_embed_dim = (
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
@@ -1073,70 +923,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
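# Quick shape illustration of the removed `get_guidance_scale_embedding` helper above: it
# turns the guidance weight `w` into a sinusoidal (Fourier) embedding, used only when the
# UNet exposes a `time_cond_proj_dim` (guidance-distilled, LCM-style models). The math below
# mirrors the helper; `embedding_dim=256` is just an assumed example value.
import torch

w = torch.tensor([7.5, 5.0]) - 1.0  # one guidance weight per prompt, as in the pipelines
embedding_dim = 256
half_dim = embedding_dim // 2
freqs = torch.exp(torch.arange(half_dim) * -(torch.log(torch.tensor(10000.0)) / (half_dim - 1)))
emb = (w * 1000.0)[:, None] * freqs[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
print(emb.shape)  # torch.Size([2, 256])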
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def guidance_rescale(self):
- return self._guidance_rescale
-
- @property
- def clip_skip(self):
- return self._clip_skip
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
- @property
- def cross_attention_kwargs(self):
- return self._cross_attention_kwargs
-
- @property
- def denoising_end(self):
- return self._denoising_end
-
- @property
- def denoising_start(self):
- return self._denoising_start
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1150,7 +936,6 @@ def __call__(
width: Optional[int] = None,
strength: float = 0.9999,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 7.5,
@@ -1164,9 +949,10 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
- ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Tuple[int, int] = None,
@@ -1178,9 +964,6 @@ def __call__(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1221,10 +1004,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
denoising_start (`float`, *optional*):
When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
@@ -1267,7 +1046,6 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
@@ -1286,6 +1064,12 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -1330,15 +1114,6 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
- callback_on_step_end (`Callable`, *optional*):
- A function that is called at the end of each denoising step during inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -1347,23 +1122,6 @@ def __call__(
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
-
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -1380,16 +1138,8 @@ def __call__(
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
- callback_on_step_end_tensor_inputs,
)
- self._guidance_scale = guidance_scale
- self._guidance_rescale = guidance_rescale
- self._clip_skip = clip_skip
- self._cross_attention_kwargs = cross_attention_kwargs
- self._denoising_end = denoising_end
- self._denoising_start = denoising_start
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1399,10 +1149,14 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
(
@@ -1415,7 +1169,7 @@ def __call__(
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -1423,19 +1177,16 @@ def __call__(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
- clip_skip=self.clip_skip,
+ clip_skip=clip_skip,
)
# 4. set timesteps
def denoising_value_valid(dnv):
- return isinstance(self.denoising_end, float) and 0 < dnv < 1
+ return isinstance(denoising_end, float) and 0 < dnv < 1
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(
- num_inference_steps,
- strength,
- device,
- denoising_start=self.denoising_start if denoising_value_valid else None,
+ num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid else None
)
# check that number of inference steps is not < 1 - as this doesn't make sense
if num_inference_steps < 1:
@@ -1467,7 +1218,7 @@ def denoising_value_valid(dnv):
num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4
- add_noise = True if self.denoising_start is None else False
+ add_noise = True if denoising_start is None else False
latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -1500,7 +1251,7 @@ def denoising_value_valid(dnv):
prompt_embeds.dtype,
device,
generator,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
)
# 8. Check that sizes of mask, masked image and latents match
@@ -1538,11 +1289,6 @@ def denoising_value_valid(dnv):
negative_target_size = target_size
add_text_embeds = pooled_prompt_embeds
- if self.text_encoder_2 is None:
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
- else:
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
@@ -1553,11 +1299,10 @@ def denoising_value_valid(dnv):
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
@@ -1567,49 +1312,34 @@ def denoising_value_valid(dnv):
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device)
- if ip_adapter_image is not None:
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
- if self.do_classifier_free_guidance:
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
- image_embeds = image_embeds.to(device)
-
# 11. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
if (
- self.denoising_end is not None
- and self.denoising_start is not None
- and denoising_value_valid(self.denoising_end)
- and denoising_value_valid(self.denoising_start)
- and self.denoising_start >= self.denoising_end
+ denoising_end is not None
+ and denoising_start is not None
+ and denoising_value_valid(denoising_end)
+ and denoising_value_valid(denoising_start)
+ and denoising_start >= denoising_end
):
raise ValueError(
- f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
- + f" {self.denoising_end} when using type float."
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
+ + f" {denoising_end} when using type float."
)
- elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
discrete_timestep_cutoff = int(
round(
self.scheduler.config.num_train_timesteps
- - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
)
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
- # 11.1 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
- self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1619,33 +1349,30 @@ def denoising_value_valid(dnv):
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
- if ip_adapter_image is not None:
- added_cond_kwargs["image_embeds"] = image_embeds
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
- cross_attention_kwargs=self.cross_attention_kwargs,
+ cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if num_channels_unet == 4:
init_latents_proper = image_latents
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
else:
init_mask = mask
@@ -1658,24 +1385,6 @@ def denoising_value_valid(dnv):
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
- negative_pooled_prompt_embeds = callback_outputs.pop(
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
- )
- add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
- add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
- mask = callback_outputs.pop("mask", mask)
- masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
-
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
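# The guidance arithmetic used in the loop above, isolated: the batch stacks the
# unconditional and text-conditioned noise predictions, and the final estimate extrapolates
# from the unconditional one toward the text-conditioned one. With guidance_scale == 1 the
# result is just the text-conditioned prediction (no classifier-free guidance); larger values
# push harder toward the prompt. Dummy tensors only.
import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 128, 128)  # [uncond, text] stacked along the batch dim
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 4, 128, 128])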
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index b14c746f9962..8cd7f46e633a 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -31,13 +31,11 @@
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
- USE_PEFT_BACKEND,
deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
- scale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -88,20 +86,6 @@
"""
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -165,9 +149,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
watermarker will be used.
"""
-
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
- _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
def __init__(
self,
@@ -298,17 +280,8 @@ def encode_prompt(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
- if self.text_encoder is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
- else:
- scale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
- else:
- scale_lora_layers(self.text_encoder_2, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -417,8 +390,7 @@ def encode_prompt(
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
- prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype
- prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -427,7 +399,7 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -461,27 +433,16 @@ def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
def check_inputs(
- self,
- prompt,
- callback_steps,
- negative_prompt=None,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
+ self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -547,7 +508,17 @@ def prepare_image_latents(
self.upcast_vae()
image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
- image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if isinstance(generator, list):
+ image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.mode()
# cast back to fp16 if needed
if needs_upcasting:
@@ -581,13 +552,11 @@ def prepare_image_latents(
return image_latents
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
- def _get_add_time_ids(
- self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
- ):
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
passed_add_embed_dim = (
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
@@ -870,6 +839,7 @@ def __call__(
prompt_embeds.dtype,
device,
do_classifier_free_guidance,
+ generator,
)
# 7. Prepare latent variables
@@ -901,17 +871,8 @@ def __call__(
# 10. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
- if self.text_encoder_2 is None:
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
- else:
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
add_time_ids = self._get_add_time_ids(
- original_size,
- crops_coords_top_left,
- target_size,
- dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
)
if do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index a0a17e8cacec..0c7120c5b3ec 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -118,51 +118,6 @@ def _preprocess_adapter_image(image, height, width):
return image
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
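# Usage sketch for the removed `retrieve_timesteps` helper above. With `num_inference_steps`
# it simply defers to `scheduler.set_timesteps`; passing an explicit descending `timesteps`
# list only works for schedulers whose `set_timesteps` accepts a `timesteps` argument,
# otherwise the helper raises. `DDIMScheduler()` with default config is only a stand-in.
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(len(timesteps), num_inference_steps)  # 30 30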
class StableDiffusionAdapterPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
@@ -197,7 +152,6 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->adapter->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
@@ -475,7 +429,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -614,8 +568,8 @@ def _default_height_width(self, height, width, image):
elif isinstance(image, torch.Tensor):
height = image.shape[-2]
- # round down to nearest multiple of `self.adapter.downscale_factor`
- height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
+ height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
if width is None:
if isinstance(image, PIL.Image.Image):
@@ -623,8 +577,8 @@ def _default_height_width(self, height, width, image):
elif isinstance(image, torch.Tensor):
width = image.shape[-1]
- # round down to nearest multiple of `self.adapter.downscale_factor`
- width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
+ width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
return height, width
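# The rounding above in concrete numbers: with an assumed total downscale factor of 8, an
# input height of 517 px is clamped to the nearest lower multiple of 8 so the adapter's
# feature maps stay aligned with the UNet's latent grid.
total_downscale_factor = 8  # assumed example value
height = 517
height = (height // total_downscale_factor) * total_downscale_factor
print(height)  # 512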
@@ -656,46 +610,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -705,7 +619,6 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -740,10 +653,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -814,8 +723,6 @@ def __call__(
prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds
)
- self._guidance_scale = guidance_scale
-
if isinstance(self.adapter, MultiAdapter):
adapter_input = []
@@ -835,12 +742,17 @@ def __call__(
else:
batch_size = prompt_embeds.shape[0]
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@@ -849,11 +761,12 @@ def __call__(
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -871,14 +784,6 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 6.5 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
# 7. Denoising loop
if isinstance(self.adapter, MultiAdapter):
adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
@@ -891,7 +796,7 @@ def __call__(
if num_images_per_prompt > 1:
for k, v in enumerate(adapter_state):
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
for k, v in enumerate(adapter_state):
adapter_state[k] = torch.cat([v] * 2, dim=0)
@@ -899,7 +804,7 @@ def __call__(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
@@ -907,14 +812,12 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
cross_attention_kwargs=cross_attention_kwargs,
- down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
- return_dict=False,
- )[0]
+ down_block_additional_residuals=[state.clone() for state in adapter_state],
+ ).sample
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index b07c98fef679..b31d478a9d67 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -123,51 +123,6 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
-def retrieve_timesteps(
- scheduler,
- num_inference_steps: Optional[int] = None,
- device: Optional[Union[str, torch.device]] = None,
- timesteps: Optional[List[int]] = None,
- **kwargs,
-):
- """
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
-
- Args:
- scheduler (`SchedulerMixin`):
- The scheduler to get timesteps from.
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
- device (`str` or `torch.device`, *optional*):
- The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
- timesteps (`List[int]`, *optional*):
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
- must be `None`.
-
- Returns:
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
- second element is the number of inference steps.
- """
- if timesteps is not None:
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
- if not accepts_timesteps:
- raise ValueError(
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
- f" timestep schedules. Please check whether you are using the correct scheduler."
- )
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- num_inference_steps = len(timesteps)
- else:
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
- timesteps = scheduler.timesteps
- return timesteps, num_inference_steps
-
-
class StableDiffusionXLAdapterPipeline(
DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
):
@@ -204,9 +159,7 @@ class StableDiffusionXLAdapterPipeline(
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
-
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
- _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
def __init__(
self,
@@ -337,17 +290,12 @@ def encode_prompt(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
- if self.text_encoder is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
- else:
- scale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if not USE_PEFT_BACKEND:
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
- else:
- scale_lora_layers(self.text_encoder_2, lora_scale)
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -463,11 +411,7 @@ def encode_prompt(
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
- if self.text_encoder_2 is not None:
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -476,12 +420,7 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
-
- if self.text_encoder_2 is not None:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
- else:
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -493,15 +432,10 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)
- if self.text_encoder is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
-
- if self.text_encoder_2 is not None:
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
- # Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder_2, lora_scale)
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder)
+ unscale_lora_layers(self.text_encoder_2)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
@@ -537,24 +471,18 @@ def check_inputs(
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -622,13 +550,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
return latents
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
- def _get_add_time_ids(
- self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
- ):
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
add_time_ids = list(original_size + crops_coords_top_left + target_size)
passed_add_embed_dim = (
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
@@ -674,8 +600,8 @@ def _default_height_width(self, height, width, image):
elif isinstance(image, torch.Tensor):
height = image.shape[-2]
- # round down to nearest multiple of `self.adapter.downscale_factor`
- height = (height // self.adapter.downscale_factor) * self.adapter.downscale_factor
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
+ height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
if width is None:
if isinstance(image, PIL.Image.Image):
@@ -683,8 +609,8 @@ def _default_height_width(self, height, width, image):
elif isinstance(image, torch.Tensor):
width = image.shape[-1]
- # round down to nearest multiple of `self.adapter.downscale_factor`
- width = (width // self.adapter.downscale_factor) * self.adapter.downscale_factor
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
+ width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
return height, width
@@ -716,46 +642,6 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
- """
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
- Args:
- timesteps (`torch.Tensor`):
- generate embedding vectors at these timesteps
- embedding_dim (`int`, *optional*, defaults to 512):
- dimension of the embeddings to generate
- dtype:
- data type of the generated embeddings
-
- Returns:
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
- """
- assert len(w.shape) == 1
- w = w * 1000.0
-
- half_dim = embedding_dim // 2
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
- emb = w.to(dtype)[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0, 1))
- assert emb.shape == (w.shape[0], embedding_dim)
- return emb
-
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -766,7 +652,6 @@ def __call__(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
- timesteps: List[int] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -820,10 +705,6 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
- timesteps (`List[int]`, *optional*):
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
- passed will be used. Must be in descending order.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -973,8 +854,6 @@ def __call__(
negative_pooled_prompt_embeds,
)
- self._guidance_scale = guidance_scale
-
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -985,6 +864,11 @@ def __call__(
device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
# 3. Encode input prompt
(
prompt_embeds,
@@ -996,7 +880,7 @@ def __call__(
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -1007,7 +891,9 @@ def __call__(
)
# 4. Prepare timesteps
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -1025,14 +911,6 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 6.5 Optionally get Guidance Scale Embedding
- timestep_cond = None
- if self.unet.config.time_cond_proj_dim is not None:
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
- timestep_cond = self.get_guidance_scale_embedding(
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
- ).to(device=device, dtype=latents.dtype)
-
# 7. Prepare added time ids & embeddings & adapter features
if isinstance(self.adapter, MultiAdapter):
adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
@@ -1045,22 +923,13 @@ def __call__(
if num_images_per_prompt > 1:
for k, v in enumerate(adapter_state):
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
for k, v in enumerate(adapter_state):
adapter_state[k] = torch.cat([v] * 2, dim=0)
add_text_embeds = pooled_prompt_embeds
- if self.text_encoder_2 is None:
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
- else:
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
-
add_time_ids = self._get_add_time_ids(
- original_size,
- crops_coords_top_left,
- target_size,
- dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
@@ -1068,12 +937,11 @@ def __call__(
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
- text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -1099,7 +967,7 @@ def __call__(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1107,27 +975,26 @@ def __call__(
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
if i < int(num_inference_steps * adapter_conditioning_factor):
- down_intrablock_additional_residuals = [state.clone() for state in adapter_state]
+ down_block_additional_residuals = [state.clone() for state in adapter_state]
else:
- down_intrablock_additional_residuals = None
+ down_block_additional_residuals = None
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
- timestep_cond=timestep_cond,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
- down_intrablock_additional_residuals=down_intrablock_additional_residuals,
+ down_block_additional_residuals=down_block_additional_residuals,
)[0]
# perform guidance
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
- if self.do_classifier_free_guidance and guidance_rescale > 0.0:
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
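# What `rescale_noise_cfg` (called just above) computes, per §3.4 of arXiv:2305.08891: the
# CFG output is renormalized to the per-sample standard deviation of the text-conditioned
# prediction, then blended back with the raw CFG output by `guidance_rescale` to counter
# over-exposure from high guidance scales. A sketch of the formula, not the library source.
import torch

def rescale_noise_cfg_sketch(noise_cfg, noise_pred_text, guidance_rescale=0.7):
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg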
@@ -1160,8 +1027,9 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type)
- # Offload all models
- self.maybe_free_model_hooks()
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py
index 8d8fdb92769b..9304d5c7d818 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/__init__.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/__init__.py
@@ -25,7 +25,6 @@
_import_structure["pipeline_text_to_video_synth"] = ["TextToVideoSDPipeline"]
_import_structure["pipeline_text_to_video_synth_img2img"] = ["VideoToVideoSDPipeline"]
_import_structure["pipeline_text_to_video_zero"] = ["TextToVideoZeroPipeline"]
- _import_structure["pipeline_text_to_video_zero_sdxl"] = ["TextToVideoZeroSDXLPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -39,7 +38,6 @@
from .pipeline_text_to_video_synth import TextToVideoSDPipeline
from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline
from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
- from .pipeline_text_to_video_zero_sdxl import TextToVideoZeroSDXLPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index 1f6650f58d2e..83c31596940e 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -96,7 +96,6 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
def __init__(
@@ -362,7 +361,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -417,22 +416,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index 6779a7b820c2..f5ac19c29d14 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -79,20 +79,6 @@
"""
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
-def retrieve_latents(
- encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
-):
- if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
- return encoder_output.latent_dist.sample(generator)
- elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
- return encoder_output.latent_dist.mode()
- elif hasattr(encoder_output, "latents"):
- return encoder_output.latents
- else:
- raise AttributeError("Could not access latents of provided encoder_output")
-
-
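# Plain-call version of what the removed `retrieve_latents` wrapper abstracts over: the VAE
# encoder returns a posterior distribution, and callers either draw a sample (the img2img /
# video-to-video path) or take its mode (the deterministic "argmax" path). Dummy config and
# tensors; `AutoencoderKL()` with defaults is only an illustration.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()
pixels = torch.randn(1, 3, 64, 64)
posterior = vae.encode(pixels).latent_dist
sampled = posterior.sample(generator=torch.Generator().manual_seed(0))  # stochastic latents
deterministic = posterior.mode()                                        # "argmax" latents
print(sampled.shape, deterministic.shape)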
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
# reshape to ncfhw
@@ -172,7 +158,6 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
-
model_cpu_offload_seq = "text_encoder->unet->vae"
def __init__(
@@ -438,7 +423,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
@@ -486,30 +471,19 @@ def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
def check_inputs(
- self,
- prompt,
- strength,
- callback_steps,
- negative_prompt=None,
- prompt_embeds=None,
- negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -561,14 +535,14 @@ def prepare_latents(self, video, timestep, batch_size, dtype, device, generator=
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
+
elif isinstance(generator, list):
init_latents = [
- retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i])
- for i in range(batch_size)
+ self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
- init_latents = retrieve_latents(self.vae.encode(video), generator=generator)
+ init_latents = self.vae.encode(video).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
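This hunk inlines the latent sampling that the removed `retrieve_latents` helper used to dispatch. A minimal stand-in showing the call shape the pipeline now relies on; the classes below are illustrative, not diffusers types.

```py
# Minimal stand-in: the helper existed to accept either an encoder output
# exposing a `latent_dist` or one that already carries `latents`. After the
# revert, the pipeline samples the distribution directly.
import torch


class FakeDist:
    def __init__(self, mean):
        self.mean = mean

    def sample(self, generator=None):
        return self.mean + torch.randn(self.mean.shape, generator=generator)

    def mode(self):
        return self.mean


class FakeEncoderOutput:
    def __init__(self, mean):
        self.latent_dist = FakeDist(mean)


out = FakeEncoderOutput(torch.zeros(1, 4, 8, 8))
latents = out.latent_dist.sample(torch.Generator().manual_seed(0))
```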
@@ -803,7 +777,6 @@ def __call__(
if output_type == "latent":
return TextToVideoSDPipelineOutput(frames=latents)
- # manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index 0f9ffbebdcf6..277726781eee 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -13,7 +13,6 @@
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import BaseOutput
-from diffusers.utils.torch_utils import randn_tensor
def rearrange_0(tensor, f):
@@ -136,7 +135,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
# Cross Frame Attention
if not is_cross_attention:
- video_length = max(1, key.size()[0] // self.batch_size)
+ video_length = key.size()[0] // self.batch_size
first_frame_index = [0] * video_length
# rearrange keys to have batch and frames in the 1st and 2nd dims respectively
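This change drops the `max(1, ...)` guard on `video_length`; the surrounding processor still reuses the first frame's keys and values for every frame. A toy illustration of that first-frame indexing with made-up sizes.

```py
# Toy illustration of the first-frame indexing used by cross-frame attention;
# tensors and sizes here are made up for demonstration.
import torch

batch_size, video_length, tokens, dim = 2, 4, 3, 8
key = torch.randn(batch_size * video_length, tokens, dim)

first_frame_index = [0] * video_length
# Group frames per batch element, then pick frame 0 for every position.
key = key.reshape(batch_size, video_length, tokens, dim)
key = key[:, first_frame_index]          # every frame now sees frame 0's keys
assert torch.equal(key[:, 1], key[:, 0])
```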
@@ -184,7 +183,6 @@ class TextToVideoPipelineOutput(BaseOutput):
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
`None` if safety checking could not be performed.
"""
-
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]
@@ -340,7 +338,7 @@ def forward_loop(self, x_t0, t0, t1, generator):
x_t1:
Forward process applied to x_t0 from time t0 to t1.
"""
- eps = randn_tensor(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
+ eps = torch.randn(x_t0.size(), generator=generator, dtype=x_t0.dtype, device=x_t0.device)
alpha_vec = torch.prod(self.scheduler.alphas[t0:t1])
x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps
return x_t1
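`forward_loop` applies the closed-form forward-diffusion jump x_t1 = sqrt(alpha_bar) * x_t0 + sqrt(1 - alpha_bar) * eps, where alpha_bar is the product of the per-step alphas between t0 and t1. A quick numerical sanity check of that formula with made-up alpha values.

```py
# Sanity check of the closed-form forward jump used above, with made-up
# alpha values rather than the scheduler's real schedule.
import torch

alphas = torch.tensor([0.99, 0.98, 0.97])
x_t0 = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(x_t0)

alpha_vec = torch.prod(alphas)  # cumulative product over [t0, t1)
x_t1 = torch.sqrt(alpha_vec) * x_t0 + torch.sqrt(1 - alpha_vec) * eps

# The squared coefficients sum to 1, so unit variance is preserved when
# x_t0 and eps both have unit variance.
assert torch.isclose(alpha_vec + (1 - alpha_vec), torch.tensor(1.0))
```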
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index 7bebed73c106..c4a25c865d88 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -156,15 +156,15 @@ def _encode_prompt(
text_encoder_output = self.text_encoder(text_input_ids.to(device))
prompt_embeds = text_encoder_output.text_embeds
- text_enc_hid_states = text_encoder_output.last_hidden_state
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
- prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
+ prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
- text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
@@ -181,7 +181,7 @@ def _encode_prompt(
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
- uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -189,9 +189,9 @@ def _encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
- seq_len = uncond_text_enc_hid_states.shape[1]
- uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
- uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -202,11 +202,11 @@ def _encode_prompt(
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
- return prompt_embeds, text_enc_hid_states, text_mask
+ return prompt_embeds, text_encoder_hidden_states, text_mask
@torch.no_grad()
def __call__(
@@ -293,7 +293,7 @@ def __call__(
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
- prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt(
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
)
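The `do_classifier_free_guidance` flag above gates the usual two-branch guidance trick. As a reminder of how the two noise predictions are combined, here is the standard formulation, not unCLIP's exact code path.

```py
# Standard classifier-free guidance combination, shown with random tensors.
import torch

guidance_scale = 4.0
noise_pred_uncond = torch.randn(1, 4, 8, 8)
noise_pred_text = torch.randn(1, 4, 8, 8)

# Push the prediction away from the unconditional branch and toward the
# conditioned branch, scaled by the guidance strength.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```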
@@ -321,7 +321,7 @@ def __call__(
latent_model_input,
timestep=t,
proj_embedding=prompt_embeds,
- encoder_hidden_states=text_enc_hid_states,
+ encoder_hidden_states=text_encoder_hidden_states,
attention_mask=text_mask,
).predicted_image_embedding
@@ -352,10 +352,10 @@ def __call__(
# decoder
- text_enc_hid_states, additive_clip_time_embeddings = self.text_proj(
+ text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
prompt_embeds=prompt_embeds,
- text_encoder_hidden_states=text_enc_hid_states,
+ text_encoder_hidden_states=text_encoder_hidden_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
@@ -377,7 +377,7 @@ def __call__(
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
- text_enc_hid_states.dtype,
+ text_encoder_hidden_states.dtype,
device,
generator,
decoder_latents,
@@ -391,7 +391,7 @@ def __call__(
noise_pred = self.decoder(
sample=latent_model_input,
timestep=t,
- encoder_hidden_states=text_enc_hid_states,
+ encoder_hidden_states=text_encoder_hidden_states,
class_labels=additive_clip_time_embeddings,
attention_mask=decoder_text_mask,
).sample
diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
index bf0a4eb475c0..9b962f6e0656 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
@@ -20,7 +20,7 @@ class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
prefix_length (`int`):
Max number of prefix tokens that will be supplied to the model.
prefix_inner_dim (`int`):
- The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
+ The hidden size of the the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
CLIP text encoder.
prefix_hidden_dim (`int`, *optional*):
Hidden dim of the MLP if we encode the prefix.
diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
index 6e97e0279350..b7829f76ec12 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
@@ -6,10 +6,9 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
-from ...models.attention import FeedForward
+from ...models.attention import AdaLayerNorm, FeedForward
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
-from ...models.normalization import AdaLayerNorm
from ...models.transformer_2d import Transformer2DModelOutput
from ...utils import logging
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 4f3e003de08e..0d5880ac0d4f 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -556,7 +556,7 @@ def encode_prompt(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
- unscale_lora_layers(self.text_encoder, lora_scale)
+ unscale_lora_layers(self.text_encoder)
return prompt_embeds, negative_prompt_embeds
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index a940cec5e46a..2ed3deeb1225 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -5,8 +5,6 @@
import torch.nn as nn
import torch.nn.functional as F
-from diffusers.utils import deprecate
-
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.activations import get_activation
@@ -281,7 +279,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
- Block type for middle of UNet, it can be one of `UNetMidBlockFlatCrossAttn`, `UNetMidBlockFlat`, or
+ Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or
`UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`):
The tuple of upsample blocks to use.
@@ -300,15 +298,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
- transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
[`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
- reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
- blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
- [`~models.unet_2d_blocks.CrossAttnDownBlockFlat`], [`~models.unet_2d_blocks.CrossAttnUpBlockFlat`],
- [`~models.unet_2d_blocks.UNetMidBlockFlatCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
@@ -342,9 +335,9 @@ class conditioning with `class_embed_type` equal to `None`.
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
- conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
- *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
- *optional*): The dimension of the `class_labels` input when
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
@@ -389,8 +382,7 @@ def __init__(
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
- reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -425,7 +417,10 @@ def __init__(
if num_attention_heads is not None:
raise ValueError(
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
+ " because of a naming issue as described in"
+ " https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing"
+ " `num_attention_heads` will only be supported in diffusers v0.19."
)
# If `num_attention_heads` is not defined (which is the case for most models)
@@ -439,42 +434,45 @@ def __init__(
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:"
+ f" {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:"
+ f" {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ "Must provide the same number of `only_cross_attention` as `down_block_types`."
+ f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ "Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
+ f" {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
+ f" {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ "Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:"
+ f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
+ f" {layers_per_block}. `down_block_types`: {down_block_types}."
)
- if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
- for layer_number_per_block in transformer_layers_per_block:
- if isinstance(layer_number_per_block, list):
- raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# input
conv_in_padding = (conv_in_kernel - 1) // 2
@@ -710,19 +708,6 @@ def __init__(
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
- elif mid_block_type == "UNetMidBlockFlat":
- self.mid_block = UNetMidBlockFlat(
- in_channels=block_out_channels[-1],
- temb_channels=blocks_time_embed_dim,
- dropout=dropout,
- num_layers=0,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_groups=norm_num_groups,
- resnet_time_scale_shift=resnet_time_scale_shift,
- add_attention=False,
- )
elif mid_block_type is None:
self.mid_block = None
else:
@@ -736,11 +721,7 @@ def __init__(
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
- reversed_transformer_layers_per_block = (
- list(reversed(transformer_layers_per_block))
- if reverse_transformer_layers_per_block is None
- else reverse_transformer_layers_per_block
- )
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
@@ -887,7 +868,8 @@ def set_default_attn_processor(self):
processor = AttnProcessor()
else:
raise ValueError(
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ "Cannot call `set_default_attn_processor` when attention processors are of type"
+ f" {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
@@ -990,7 +972,7 @@ def disable_freeu(self):
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
- if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
setattr(upsample_block, k, None)
def forward(
@@ -1005,7 +987,6 @@ def forward(
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
- down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
@@ -1050,13 +1031,6 @@ def forward(
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containing additional embeddings that, if specified, are added to the embeddings that
are passed along to the UNet blocks.
- down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
- additional residuals to be added to UNet long skip connections from down blocks to up blocks for
- example from ControlNet side model(s)
- mid_block_additional_residual (`torch.Tensor`, *optional*):
- additional residual to be added to UNet mid block output, for example from ControlNet side model
- down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
- additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -1073,11 +1047,9 @@ def forward(
forward_upsample_size = False
upsample_size = None
- for dim in sample.shape[-2:]:
- if dim % default_overall_up_factor != 0:
- # Forward upsample size to force interpolation output size.
- forward_upsample_size = True
- break
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
@@ -1155,7 +1127,8 @@ def forward(
# Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
+ " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
@@ -1165,12 +1138,14 @@ def forward(
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
+ " the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
+ " the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
@@ -1182,7 +1157,8 @@ def forward(
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
+ " keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
aug_emb = self.add_embedding(image_embs)
@@ -1190,7 +1166,8 @@ def forward(
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
+ " the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
hint = added_cond_kwargs.get("hint")
@@ -1208,7 +1185,8 @@ def forward(
# Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which"
+ " requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
@@ -1217,19 +1195,11 @@ def forward(
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
+ " the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
- elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
- if "image_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
- )
- image_embeds = added_cond_kwargs.get("image_embeds")
- image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
- encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
-
# 2. pre-process
sample = self.conv_in(sample)
@@ -1246,30 +1216,15 @@ def forward(
scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
- # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
- is_adapter = down_intrablock_additional_residuals is not None
- # maintain backward compatibility for legacy usage, where
- # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
- # but can only use one or the other
- if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
- deprecate(
- "T2I should not use down_block_additional_residuals",
- "1.3.0",
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
- and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
- for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
- standard_warn=False,
- )
- down_intrablock_additional_residuals = down_block_additional_residuals
- is_adapter = True
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlockFlat
additional_residuals = {}
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
- additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
sample, res_samples = downsample_block(
hidden_states=sample,
@@ -1282,8 +1237,9 @@ def forward(
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
- if is_adapter and len(down_intrablock_additional_residuals) > 0:
- sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_adapter and len(down_block_additional_residuals) > 0:
+ sample += down_block_additional_residuals.pop(0)
down_block_res_samples += res_samples
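With `down_intrablock_additional_residuals` removed, T2I-Adapter residuals are again passed through `down_block_additional_residuals`, and the forward pass distinguishes adapters from ControlNets purely by whether a mid-block residual is present (ControlNet residuals go onto the long skip connections, adapter residuals are added inside the down blocks). A simplified, hypothetical sketch of that dispatch:

```py
# Simplified, hypothetical dispatch mirroring the flags computed above.
def classify_residual_source(mid_block_residual, down_block_residuals):
    is_controlnet = mid_block_residual is not None and down_block_residuals is not None
    is_adapter = mid_block_residual is None and down_block_residuals is not None
    return is_controlnet, is_adapter


assert classify_residual_source(None, [object()]) == (False, True)       # adapter
assert classify_residual_source(object(), [object()]) == (True, False)   # controlnet
assert classify_residual_source(None, None) == (False, False)            # neither
```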
@@ -1300,25 +1256,21 @@ def forward(
# 4. mid
if self.mid_block is not None:
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
- sample = self.mid_block(
- sample,
- emb,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- cross_attention_kwargs=cross_attention_kwargs,
- encoder_attention_mask=encoder_attention_mask,
- )
- else:
- sample = self.mid_block(sample, emb)
-
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
# To support T2I-Adapter-XL
if (
is_adapter
- and len(down_intrablock_additional_residuals) > 0
- and sample.shape == down_intrablock_additional_residuals[0].shape
+ and len(down_block_additional_residuals) > 0
+ and sample.shape == down_block_additional_residuals[0].shape
):
- sample += down_intrablock_additional_residuals.pop(0)
+ sample += down_block_additional_residuals.pop(0)
if is_controlnet:
sample = sample + mid_block_additional_residual
@@ -1363,7 +1315,7 @@ def forward(
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
- unscale_lora_layers(self, lora_scale)
+ unscale_lora_layers(self)
if not return_dict:
return (sample,)
@@ -1484,6 +1436,7 @@ def forward(self, input_tensor, temb):
return output_tensor
+# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class DownBlockFlat(nn.Module):
def __init__(
self,
@@ -1497,9 +1450,9 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_downsample: bool = True,
- downsample_padding: int = 1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
):
super().__init__()
resnets = []
@@ -1536,9 +1489,7 @@ def __init__(
self.gradient_checkpointing = False
- def forward(
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ def forward(self, hidden_states, temb=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
@@ -1572,6 +1523,7 @@ def custom_forward(*inputs):
return hidden_states, output_states
+# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class CrossAttnDownBlockFlat(nn.Module):
def __init__(
self,
@@ -1580,22 +1532,22 @@ def __init__(
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- downsample_padding: int = 1,
- add_downsample: bool = True,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- attention_type: str = "default",
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ attention_type="default",
):
super().__init__()
resnets = []
@@ -1603,8 +1555,6 @@ def __init__(
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * num_layers
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
@@ -1628,7 +1578,7 @@ def __init__(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
- num_layers=transformer_layers_per_block[i],
+ num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1672,8 +1622,8 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- additional_residuals: Optional[torch.FloatTensor] = None,
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ additional_residuals=None,
+ ):
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
@@ -1741,7 +1691,7 @@ def __init__(
prev_output_channel: int,
out_channels: int,
temb_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
@@ -1749,8 +1699,8 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
):
super().__init__()
resnets = []
@@ -1784,14 +1734,7 @@ def __init__(
self.gradient_checkpointing = False
self.resolution_idx = resolution_idx
- def forward(
- self,
- hidden_states: torch.FloatTensor,
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
- temb: Optional[torch.FloatTensor] = None,
- upsample_size: Optional[int] = None,
- scale: float = 1.0,
- ) -> torch.FloatTensor:
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
is_freeu_enabled = (
getattr(self, "s1", None)
and getattr(self, "s2", None)
@@ -1852,24 +1795,24 @@ def __init__(
out_channels: int,
prev_output_channel: int,
temb_channels: int,
- resolution_idx: Optional[int] = None,
+ resolution_idx: int = None,
dropout: float = 0.0,
num_layers: int = 1,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- cross_attention_dim: int = 1280,
- output_scale_factor: float = 1.0,
- add_upsample: bool = True,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- only_cross_attention: bool = False,
- upcast_attention: bool = False,
- attention_type: str = "default",
+ num_attention_heads=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ attention_type="default",
):
super().__init__()
resnets = []
@@ -1878,9 +1821,6 @@ def __init__(
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * num_layers
-
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1905,7 +1845,7 @@ def __init__(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
- num_layers=transformer_layers_per_block[i],
+ num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -1946,7 +1886,7 @@ def forward(
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ ):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_freeu_enabled = (
getattr(self, "s1", None)
@@ -2018,132 +1958,6 @@ def custom_forward(*inputs):
return hidden_states
-# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
-class UNetMidBlockFlat(nn.Module):
- """
- A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
-
- Args:
- in_channels (`int`): The number of input channels.
- temb_channels (`int`): The number of temporal embedding channels.
- dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
- num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
- resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
- resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
- The type of normalization to apply to the time embeddings. This can help to improve the performance of the
- model on tasks with long-range temporal dependencies.
- resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
- resnet_groups (`int`, *optional*, defaults to 32):
- The number of groups to use in the group normalization layers of the resnet blocks.
- attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
- resnet_pre_norm (`bool`, *optional*, defaults to `True`):
- Whether to use pre-normalization for the resnet blocks.
- add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
- attention_head_dim (`int`, *optional*, defaults to 1):
- Dimension of a single attention head. The number of attention heads is determined based on this value and
- the number of input channels.
- output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
-
- Returns:
- `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
- in_channels, height, width)`.
-
- """
-
- def __init__(
- self,
- in_channels: int,
- temb_channels: int,
- dropout: float = 0.0,
- num_layers: int = 1,
- resnet_eps: float = 1e-6,
- resnet_time_scale_shift: str = "default", # default, spatial
- resnet_act_fn: str = "swish",
- resnet_groups: int = 32,
- attn_groups: Optional[int] = None,
- resnet_pre_norm: bool = True,
- add_attention: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = 1.0,
- ):
- super().__init__()
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
- self.add_attention = add_attention
-
- if attn_groups is None:
- attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
-
- # there is always at least one resnet
- resnets = [
- ResnetBlockFlat(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- ]
- attentions = []
-
- if attention_head_dim is None:
- logger.warn(
- f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
- )
- attention_head_dim = in_channels
-
- for _ in range(num_layers):
- if self.add_attention:
- attentions.append(
- Attention(
- in_channels,
- heads=in_channels // attention_head_dim,
- dim_head=attention_head_dim,
- rescale_output_factor=output_scale_factor,
- eps=resnet_eps,
- norm_num_groups=attn_groups,
- spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
- residual_connection=True,
- bias=True,
- upcast_softmax=True,
- _from_deprecated_attn_block=True,
- )
- )
- else:
- attentions.append(None)
-
- resnets.append(
- ResnetBlockFlat(
- in_channels=in_channels,
- out_channels=in_channels,
- temb_channels=temb_channels,
- eps=resnet_eps,
- groups=resnet_groups,
- dropout=dropout,
- time_embedding_norm=resnet_time_scale_shift,
- non_linearity=resnet_act_fn,
- output_scale_factor=output_scale_factor,
- pre_norm=resnet_pre_norm,
- )
- )
-
- self.attentions = nn.ModuleList(attentions)
- self.resnets = nn.ModuleList(resnets)
-
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
- hidden_states = self.resnets[0](hidden_states, temb)
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
- if attn is not None:
- hidden_states = attn(hidden_states, temb=temb)
- hidden_states = resnet(hidden_states, temb)
-
- return hidden_states
-
-
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
def __init__(
@@ -2152,19 +1966,19 @@ def __init__(
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- num_attention_heads: int = 1,
- output_scale_factor: float = 1.0,
- cross_attention_dim: int = 1280,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- upcast_attention: bool = False,
- attention_type: str = "default",
+ num_attention_heads=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ attention_type="default",
):
super().__init__()
@@ -2172,10 +1986,6 @@ def __init__(
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
- # support for variable transformer layers per block
- if isinstance(transformer_layers_per_block, int):
- transformer_layers_per_block = [transformer_layers_per_block] * num_layers
-
# there is always at least one resnet
resnets = [
ResnetBlockFlat(
@@ -2193,14 +2003,14 @@ def __init__(
]
attentions = []
- for i in range(num_layers):
+ for _ in range(num_layers):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
- num_layers=transformer_layers_per_block[i],
+ num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
@@ -2304,12 +2114,12 @@ def __init__(
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
- attention_head_dim: int = 1,
- output_scale_factor: float = 1.0,
- cross_attention_dim: int = 1280,
- skip_time_act: bool = False,
- only_cross_attention: bool = False,
- cross_attention_norm: Optional[str] = None,
+ attention_head_dim=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ skip_time_act=False,
+ only_cross_attention=False,
+ cross_attention_norm=None,
):
super().__init__()
@@ -2385,7 +2195,7 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
+ ):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
index 8f8bf260ca56..a248c25a5592 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -58,7 +58,6 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
-
model_cpu_offload_seq = "bert->unet->vqvae"
tokenizer: CLIPTokenizer
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index bcad6f93ef96..4f9c0bd9f4e7 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -52,7 +52,6 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
-
model_cpu_offload_seq = "bert->unet->vqvae"
image_feature_extractor: CLIPImageProcessor
diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index d8f947e64af7..24ced7620350 100644
--- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -51,7 +51,6 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
-
model_cpu_offload_seq = "bert->unet->vqvae"
tokenizer: CLIPTokenizer
@@ -256,22 +255,17 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
- callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
if prompt is not None and prompt_embeds is not None:
raise ValueError(
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
index 00d6f01beced..b3aac39386bc 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
@@ -17,8 +17,6 @@
import torch.nn as nn
from ...models.attention_processor import Attention
-from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
-from ...utils import USE_PEFT_BACKEND
class WuerstchenLayerNorm(nn.LayerNorm):
@@ -34,8 +32,7 @@ def forward(self, x):
class TimestepBlock(nn.Module):
def __init__(self, c, c_timestep):
super().__init__()
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
- self.mapper = linear_cls(c_timestep, c * 2)
+ self.mapper = nn.Linear(c_timestep, c * 2)
def forward(self, x, t):
a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -45,14 +42,10 @@ def forward(self, x, t):
class ResBlock(nn.Module):
def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
super().__init__()
-
- conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
-
- self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+ self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
- linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
+ nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
)
def forward(self, x, x_skip=None):
@@ -80,13 +73,10 @@ def forward(self, x):
class AttnBlock(nn.Module):
def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
super().__init__()
-
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
-
self.self_attn = self_attn
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
- self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
+ self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
def forward(self, x, kv):
kv = self.kv_mapper(kv)
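These hunks replace the conditional `LoRACompatibleLinear`/`LoRACompatibleConv` constructors with plain `nn.Linear`/`nn.Conv2d`. The removed pattern is a feature-flag constructor choice; below is a sketch with stand-in names (`HAS_PEFT`, `MyLoRALinear`), not the diffusers symbols.

```py
# Sketch of the constructor-selection pattern removed above.
import torch.nn as nn

HAS_PEFT = False


class MyLoRALinear(nn.Linear):
    """Placeholder for a LoRA-aware linear layer."""


def make_linear(in_features: int, out_features: int) -> nn.Module:
    # When a PEFT backend handles LoRA injection externally, a plain
    # nn.Linear suffices; otherwise the LoRA-aware variant is used.
    linear_cls = nn.Linear if HAS_PEFT else MyLoRALinear
    return linear_cls(in_features, out_features)


layer = make_linear(1024, 1280)
```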
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
index a7d9e32fb6c9..9bd29b59b3af 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -14,42 +14,25 @@
# limitations under the License.
import math
-from typing import Dict, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import UNet2DConditionLoadersMixin
-from ...models.attention_processor import (
- ADDED_KV_ATTENTION_PROCESSORS,
- CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
- AttnAddedKVProcessor,
- AttnProcessor,
-)
-from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
-class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
- unet_name = "prior"
- _supports_gradient_checkpointing = True
-
+class WuerstchenPrior(ModelMixin, ConfigMixin):
@register_to_config
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
super().__init__()
- conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
- linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
-
self.c_r = c_r
- self.projection = conv_cls(c_in, c, kernel_size=1)
+ self.projection = nn.Conv2d(c_in, c, kernel_size=1)
self.cond_mapper = nn.Sequential(
- linear_cls(c_cond, c),
+ nn.Linear(c_cond, c),
nn.LeakyReLU(0.2),
- linear_cls(c, c),
+ nn.Linear(c, c),
)
self.blocks = nn.ModuleList()
@@ -59,93 +42,9 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro
self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
self.out = nn.Sequential(
WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
- conv_cls(c, c_in * 2, kernel_size=1),
+ nn.Conv2d(c, c_in * 2, kernel_size=1),
)
- self.gradient_checkpointing = False
- self.set_default_attn_processor()
-
- @property
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
- ):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor, _remove_lora=_remove_lora)
- else:
- module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
- def set_default_attn_processor(self):
- """
- Disables custom attention processors and sets the default attention implementation.
- """
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnAddedKVProcessor()
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
- processor = AttnProcessor()
- else:
- raise ValueError(
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
- )
-
- self.set_attn_processor(processor, _remove_lora=True)
-
- def _set_gradient_checkpointing(self, module, value=False):
- self.gradient_checkpointing = value
-
def gen_r_embedding(self, r, max_positions=10000):
r = r * max_positions
half_dim = self.c_r // 2
@@ -162,42 +61,12 @@ def forward(self, x, r, c):
x = self.projection(x)
c_embed = self.cond_mapper(c)
r_embed = self.gen_r_embedding(r)
-
- if self.training and self.gradient_checkpointing:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- return module(*inputs)
-
- return custom_forward
-
- if is_torch_version(">=", "1.11.0"):
- for block in self.blocks:
- if isinstance(block, AttnBlock):
- x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), x, c_embed, use_reentrant=False
- )
- elif isinstance(block, TimestepBlock):
- x = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block), x, r_embed, use_reentrant=False
- )
- else:
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
+ for block in self.blocks:
+ if isinstance(block, AttnBlock):
+ x = block(x, c_embed)
+ elif isinstance(block, TimestepBlock):
+ x = block(x, r_embed)
else:
- for block in self.blocks:
- if isinstance(block, AttnBlock):
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
- elif isinstance(block, TimestepBlock):
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
- else:
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
- else:
- for block in self.blocks:
- if isinstance(block, AttnBlock):
- x = block(x, c_embed)
- elif isinstance(block, TimestepBlock):
- x = block(x, r_embed)
- else:
- x = block(x)
+ x = block(x)
a, b = self.out(x).chunk(2, dim=1)
return (x_in - a) / ((1 - b).abs() + 1e-5)
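
The prior's sinusoidal timestep embedding (`gen_r_embedding`, shown truncated above) is untouched by this change. As a point of reference only, a minimal sketch of that kind of embedding could look like the following; the function name and exact scaling are assumptions for illustration, not the library's implementation:

```py
import math

import torch


def sinusoidal_r_embedding(r: torch.Tensor, dim: int, max_positions: int = 10000) -> torch.Tensor:
    # r: 1D tensor of timestep ratios, already scaled by `max_positions`.
    half_dim = dim // 2
    freqs = torch.exp(-math.log(max_positions) * torch.arange(half_dim, dtype=torch.float32) / (half_dim - 1))
    args = r.float()[:, None] * freqs[None, :]
    emb = torch.cat([args.sin(), args.cos()], dim=-1)
    if dim % 2 == 1:  # pad when the target dimension is odd
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb
```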
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
index ed9ce91cb292..6caa09a46ce0 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel
@@ -73,12 +73,6 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
"""
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
- _callback_tensor_inputs = [
- "latents",
- "text_encoder_hidden_states",
- "negative_prompt_embeds",
- "image_embeddings",
- ]
def __init__(
self,
@@ -193,18 +187,6 @@ def encode_prompt(
# to avoid doing two forward passes
return text_encoder_hidden_states, uncond_text_encoder_hidden_states
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -220,9 +202,8 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
@@ -261,15 +242,12 @@ def __call__(
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
Examples:
@@ -279,33 +257,10 @@ def __call__(
embeddings.
"""
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
# 0. Define commonly used variables
device = self._execution_device
dtype = self.decoder.dtype
- self._guidance_scale = guidance_scale
+ do_classifier_free_guidance = guidance_scale > 1.0
# 1. Check inputs. Raise error if not correct
if not isinstance(prompt, list):
@@ -314,7 +269,7 @@ def __call__(
else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
@@ -343,7 +298,7 @@ def __call__(
prompt,
device,
image_embeddings.size(0) * num_images_per_prompt,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
negative_prompt,
)
text_encoder_hidden_states = (
@@ -368,26 +323,25 @@ def __call__(
latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop
- self._num_timesteps = len(timesteps[:-1])
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
- if self.do_classifier_free_guidance
+ if do_classifier_free_guidance
else image_embeddings
)
# 7. Denoise latents
predicted_latents = self.decoder(
- torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
- r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio,
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
+ r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio,
effnet=effnet,
clip=text_encoder_hidden_states,
)
# 8. Check for classifier free guidance and apply it
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2)
- predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
+ predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, guidance_scale)
# 9. Renoise latents to next timestep
latents = self.scheduler.step(
@@ -397,42 +351,26 @@ def __call__(
generator=generator,
).prev_sample
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- image_embeddings = callback_outputs.pop("image_embeddings", image_embeddings)
- text_encoder_hidden_states = callback_outputs.pop(
- "text_encoder_hidden_states", text_encoder_hidden_states
- )
-
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
- if output_type not in ["pt", "np", "pil", "latent"]:
- raise ValueError(
- f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
- )
-
- if not output_type == "latent":
- # 10. Scale and decode the image latents with vq-vae
- latents = self.vqgan.config.scale_factor * latents
- images = self.vqgan.decode(latents).sample.clamp(0, 1)
- if output_type == "np":
- images = images.permute(0, 2, 3, 1).cpu().numpy()
- elif output_type == "pil":
- images = images.permute(0, 2, 3, 1).cpu().numpy()
- images = self.numpy_to_pil(images)
- else:
- images = latents
+ # 10. Scale and decode the image latents with vq-vae
+ latents = self.vqgan.config.scale_factor * latents
+ images = self.vqgan.decode(latents).sample.clamp(0, 1)
# Offload all models
self.maybe_free_model_hooks()
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}")
+
+ if output_type == "np":
+ images = images.permute(0, 2, 3, 1).cpu().numpy()
+ elif output_type == "pil":
+ images = images.permute(0, 2, 3, 1).cpu().numpy()
+ images = self.numpy_to_pil(images)
+
if not return_dict:
return images
return ImagePipelineOutput(images)
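
The denoising loop above applies classifier-free guidance with `torch.lerp`. As a minimal, self-contained sketch of that step (the helper name is ours, not the pipeline's):

```py
import torch


def apply_cfg(pred_uncond: torch.Tensor, pred_cond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # torch.lerp(a, b, w) = a + w * (b - a); with w > 1 this extrapolates past the
    # conditional prediction, which is the usual classifier-free guidance update.
    return torch.lerp(pred_uncond, pred_cond, guidance_scale)


# In the pipeline, both branches run in one forward pass and are split afterwards:
# pred_cond, pred_uncond = predicted_latents.chunk(2)
# predicted_latents = apply_cfg(pred_uncond, pred_cond, guidance_scale)
```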
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
index d4de47ba0c9e..888d3c0dd74b 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import deprecate, replace_example_docstring
+from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
@@ -161,11 +161,10 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
+ prior_callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ prior_callback_steps: int = 1,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
@@ -227,23 +226,19 @@ def __call__(
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
- prior_callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
- int, callback_kwargs: Dict)`.
- prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
- list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
- the `._callback_tensor_inputs` attribute of your pipeline class.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
+ prior_callback (`Callable`, *optional*):
+ A function that will be called every `prior_callback_steps` steps during inference. The function will
+ be called with the following arguments: `prior_callback(step: int, timestep: int, latents:
+ torch.FloatTensor)`.
+ prior_callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `prior_callback` function will be called. If not specified, the callback will be
+ called at every step.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
Examples:
@@ -251,22 +246,6 @@ def __call__(
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
- prior_kwargs = {}
- if kwargs.get("prior_callback", None) is not None:
- prior_kwargs["callback"] = kwargs.pop("prior_callback")
- deprecate(
- "prior_callback",
- "1.0.0",
- "Passing `prior_callback` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
- )
- if kwargs.get("prior_callback_steps", None) is not None:
- deprecate(
- "prior_callback_steps",
- "1.0.0",
- "Passing `prior_callback_steps` as an input argument to `__call__` is deprecated, consider use `prior_callback_on_step_end`",
- )
- prior_kwargs["callback_steps"] = kwargs.pop("prior_callback_steps")
-
prior_outputs = self.prior_pipe(
prompt=prompt if prompt_embeds is None else None,
height=height,
@@ -282,9 +261,8 @@ def __call__(
latents=latents,
output_type="pt",
return_dict=False,
- callback_on_step_end=prior_callback_on_step_end,
- callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
- **prior_kwargs,
+ callback=prior_callback,
+ callback_steps=prior_callback_steps,
)
image_embeddings = prior_outputs[0]
@@ -298,9 +276,8 @@ def __call__(
generator=generator,
output_type=output_type,
return_dict=return_dict,
- callback_on_step_end=callback_on_step_end,
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
- **kwargs,
+ callback=callback,
+ callback_steps=callback_steps,
)
return outputs
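
With the combined pipeline back on the step-callback API, a call might look like the sketch below; the checkpoint id and the prompt are assumptions for illustration:

```py
import torch

from diffusers import WuerstchenCombinedPipeline

pipe = WuerstchenCombinedPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16  # checkpoint id assumed
).to("cuda")


def on_decoder_step(step: int, timestep: int, latents: torch.FloatTensor) -> None:
    # Invoked every `callback_steps` decoder steps with the current latents.
    print(f"decoder step {step}, t={timestep}, latents {tuple(latents.shape)}")


image = pipe(
    "an astronaut riding a horse, detailed illustration",
    prior_callback=lambda step, t, latents: print(f"prior step {step}"),
    prior_callback_steps=5,
    callback=on_decoder_step,
    callback_steps=1,
).images[0]
```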
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
index 8047f159677a..dba6d7bb06db 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
@@ -14,15 +14,18 @@
from dataclasses import dataclass
from math import ceil
-from typing import Callable, Dict, List, Optional, Union
+from typing import Callable, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
-from ...loaders import LoraLoaderMixin
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
+from ...utils import (
+ BaseOutput,
+ logging,
+ replace_example_docstring,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .modeling_wuerstchen_prior import WuerstchenPrior
@@ -62,7 +65,7 @@ class WuerstchenPriorPipelineOutput(BaseOutput):
image_embeddings: Union[torch.FloatTensor, np.ndarray]
-class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
+class WuerstchenPriorPipeline(DiffusionPipeline):
"""
Pipeline for generating image prior for Wuerstchen.
@@ -87,10 +90,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
Default resolution for multiple images generated.
"""
- unet_name = "prior"
- text_encoder_name = "text_encoder"
model_cpu_offload_seq = "text_encoder->prior"
- _callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
def __init__(
self,
@@ -261,18 +261,6 @@ def check_inputs(
In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
)
- @property
- def guidance_scale(self):
- return self._guidance_scale
-
- @property
- def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
-
- @property
- def num_timesteps(self):
- return self._num_timesteps
-
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -291,9 +279,8 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pt",
return_dict: bool = True,
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
- **kwargs,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
):
"""
Function invoked when calling the pipeline for generation.
@@ -341,15 +328,12 @@ def __call__(
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
- callback_on_step_end (`Callable`, *optional*):
- A function that calls at the end of each denoising steps during the inference. The function is called
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
- `callback_on_step_end_tensor_inputs`.
- callback_on_step_end_tensor_inputs (`List`, *optional*):
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeline class.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
Examples:
@@ -359,32 +343,9 @@ def __call__(
generated image embeddings.
"""
- callback = kwargs.pop("callback", None)
- callback_steps = kwargs.pop("callback_steps", None)
-
- if callback is not None:
- deprecate(
- "callback",
- "1.0.0",
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
- if callback_steps is not None:
- deprecate(
- "callback_steps",
- "1.0.0",
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
- )
-
- if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
- ):
- raise ValueError(
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
- )
-
# 0. Define commonly used variables
device = self._execution_device
- self._guidance_scale = guidance_scale
+ do_classifier_free_guidance = guidance_scale > 1.0
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -399,7 +360,7 @@ def __call__(
else:
raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
if negative_prompt is not None and not isinstance(negative_prompt, list):
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt]
@@ -412,7 +373,7 @@ def __call__(
prompt,
negative_prompt,
num_inference_steps,
- self.do_classifier_free_guidance,
+ do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
@@ -422,7 +383,7 @@ def __call__(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
+ do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
@@ -455,22 +416,21 @@ def __call__(
latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler)
# 6. Run denoising loop
- self._num_timesteps = len(timesteps[:-1])
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype)
# 7. Denoise image embeddings
predicted_image_embedding = self.prior(
- torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
- r=torch.cat([ratio] * 2) if self.do_classifier_free_guidance else ratio,
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
+ r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio,
c=text_encoder_hidden_states,
)
# 8. Check for classifier free guidance and apply it
- if self.do_classifier_free_guidance:
+ if do_classifier_free_guidance:
predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2)
predicted_image_embedding = torch.lerp(
- predicted_image_embedding_uncond, predicted_image_embedding_text, self.guidance_scale
+ predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale
)
# 9. Renoise latents to next timestep
@@ -481,18 +441,6 @@ def __call__(
generator=generator,
).prev_sample
- if callback_on_step_end is not None:
- callback_kwargs = {}
- for k in callback_on_step_end_tensor_inputs:
- callback_kwargs[k] = locals()[k]
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
- latents = callback_outputs.pop("latents", latents)
- text_encoder_hidden_states = callback_outputs.pop(
- "text_encoder_hidden_states", text_encoder_hidden_states
- )
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
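
For reference, the prior and decoder pipelines are typically chained as sketched below; the checkpoint ids are assumptions and the arguments are kept to a minimum:

```py
import torch

from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline

prior = WuerstchenPriorPipeline.from_pretrained(
    "warp-ai/wuerstchen-prior", torch_dtype=torch.float16  # checkpoint id assumed
).to("cuda")
decoder = WuerstchenDecoderPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16  # checkpoint id assumed
).to("cuda")

prompt = "an astronaut riding a horse, detailed illustration"

# The prior produces image embeddings ("pt" output by default) ...
prior_output = prior(prompt, guidance_scale=4.0)

# ... which the decoder turns into pixels.
image = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    guidance_scale=0.0,
).images[0]
```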
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 40c435dd5637..c6d1ee6d1006 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -38,8 +38,6 @@
_dummy_modules.update(get_objects_from_module(dummy_pt_objects))
else:
- _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
- _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
@@ -57,10 +55,11 @@
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
_import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
- _import_structure["scheduling_lcm"] = ["LCMScheduler"]
+ _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
_import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
+ _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
_import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
_import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
_import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
@@ -128,8 +127,6 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
- from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
- from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
@@ -147,10 +144,11 @@
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
- from .scheduling_lcm import LCMScheduler
+ from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
+ from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_unipc_multistep import UniPCMultistepScheduler
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
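
After this change to the lazy import table, the legacy schedulers resolve from the schedulers package again, while `LCMScheduler` and `ConsistencyDecoderScheduler` are no longer exported from it; roughly:

```py
# Resolvable after this change:
from diffusers.schedulers import KarrasVeScheduler, ScoreSdeVeScheduler, ScoreSdeVpScheduler

# No longer exported by this module after this change:
# from diffusers.schedulers import ConsistencyDecoderScheduler, LCMScheduler
```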
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index d325cde7d9d4..5881874ab57a 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -208,7 +208,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
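
The hunk above only re-wraps the `scaled_linear` branch for line length; the computed schedule is unchanged. As a standalone sketch of that schedule (function name and defaults are ours):

```py
import torch


def scaled_linear_betas(
    beta_start: float = 0.00085, beta_end: float = 0.012, num_train_timesteps: int = 1000
) -> torch.Tensor:
    # Linear in sqrt(beta), then squared; the schedule used by latent-diffusion-style models.
    return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
```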
diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py
index ea2d4945bd75..cc35046b1b6f 100644
--- a/src/diffusers/schedulers/scheduling_ddim_inverse.py
+++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -204,7 +204,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py
index acc46242b401..8d698f67328e 100644
--- a/src/diffusers/schedulers/scheduling_ddim_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -215,7 +215,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index c4a3eb43577c..bbc390a5d9ca 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -160,7 +160,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index 6f2bebfb5a38..ca17ca5499e7 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -170,7 +170,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py
index bafa6d7f1b87..781efb12b18b 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py
@@ -211,15 +211,24 @@ def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
- timesteps: torch.FloatTensor,
+ timesteps: torch.IntTensor,
) -> torch.FloatTensor:
- device = original_samples.device
- dtype = original_samples.dtype
- alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
- timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
- )
- noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
- return noisy_samples.to(dtype=dtype)
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
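
The rewritten `add_noise` relies on a small broadcasting idiom: flatten the per-timestep coefficient, then append trailing singleton dimensions until it matches the sample. As an isolated sketch (names are ours):

```py
import torch


def match_dims(coeff: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
    # Flatten a per-timestep coefficient (e.g. sqrt(alpha_cumprod[t])) and append
    # singleton dims until it broadcasts against a (B, C, H, W) sample.
    coeff = coeff.flatten()
    while coeff.dim() < sample.dim():
        coeff = coeff.unsqueeze(-1)
    return coeff


# noisy = match_dims(sqrt_alpha_prod, x0) * x0 + match_dims(sqrt_one_minus_alpha_prod, x0) * noise
```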
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 6aa994676577..a6afe744bd88 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -149,7 +149,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -291,7 +293,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -323,20 +325,8 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- # Hack to make sure that other schedulers which copy this function don't break
- # TODO: Add this logic to the other schedulers
- if hasattr(self.config, "sigma_min"):
- sigma_min = self.config.sigma_min
- else:
- sigma_min = None
-
- if hasattr(self.config, "sigma_max"):
- sigma_max = self.config.sigma_max
- else:
- sigma_max = None
-
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
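
The simplified `_convert_to_karras` above always derives `sigma_min` and `sigma_max` from the passed-in sigmas. The schedule itself is the Karras et al. (2022) construction, roughly:

```py
import numpy as np


def karras_sigmas(sigma_min: float, sigma_max: float, num_inference_steps: int, rho: float = 7.0) -> np.ndarray:
    # Interpolate linearly in sigma**(1/rho) and raise back to the rho-th power;
    # rho = 7.0 is the value used in the paper and concentrates steps near sigma_min.
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
```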
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 4b638547b38a..6b1a43630fa6 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -117,17 +117,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
- euler_at_final (`bool`, defaults to `False`):
- Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
- richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
- steps, but sometimes may result in blurring.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
- use_lu_lambdas (`bool`, *optional*, defaults to `False`):
- Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
- the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
- `lambda(t)`.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
@@ -162,9 +154,7 @@ def __init__(
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
- euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
- use_lu_lambdas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
@@ -176,7 +166,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -266,12 +258,6 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
- elif self.config.use_lu_lambdas:
- lambdas = np.flip(log_sigmas.copy())
- lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
- sigmas = np.exp(lambdas)
- timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
- sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -327,7 +313,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -358,20 +344,8 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- # Hack to make sure that other schedulers which copy this function don't break
- # TODO: Add this logic to the other schedulers
- if hasattr(self.config, "sigma_min"):
- sigma_min = self.config.sigma_min
- else:
- sigma_min = None
-
- if hasattr(self.config, "sigma_max"):
- sigma_max = self.config.sigma_max
- else:
- sigma_max = None
-
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
@@ -380,19 +354,6 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
- def _convert_to_lu(self, in_lambdas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
- """Constructs the noise schedule of Lu et al. (2022)."""
-
- lambda_min: float = in_lambdas[-1].item()
- lambda_max: float = in_lambdas[0].item()
-
- rho = 1.0 # 1.0 is the value used in the paper
- ramp = np.linspace(0, 1, num_inference_steps)
- min_inv_rho = lambda_min ** (1 / rho)
- max_inv_rho = lambda_max ** (1 / rho)
- lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
- return lambdas
-
def convert_model_output(
self,
model_output: torch.FloatTensor,
@@ -826,9 +787,8 @@ def step(
if self.step_index is None:
self._init_step_index(timestep)
- # Improve numerical stability for small number of steps
- lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
- self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
+ lower_order_final = (
+ (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index e762c0ec8bba..fa8f362bd3b5 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -117,10 +117,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
- euler_at_final (`bool`, defaults to `False`):
- Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail
- richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
- steps, but sometimes may result in blurring.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
@@ -158,7 +154,6 @@ def __init__(
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
- euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
@@ -171,7 +166,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -326,7 +323,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -358,20 +355,8 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- # Hack to make sure that other schedulers which copy this function don't break
- # TODO: Add this logic to the other schedulers
- if hasattr(self.config, "sigma_min"):
- sigma_min = self.config.sigma_min
- else:
- sigma_min = None
-
- if hasattr(self.config, "sigma_max"):
- sigma_max = self.config.sigma_max
- else:
- sigma_max = None
-
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
@@ -819,9 +804,8 @@ def step(
if self.step_index is None:
self._init_step_index(timestep)
- # Improve numerical stability for small number of steps
- lower_order_final = (self.step_index == len(self.timesteps) - 1) and (
- self.config.euler_at_final or (self.config.lower_order_final and len(self.timesteps) < 15)
+ lower_order_final = (
+ (self.step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
)
lower_order_second = (
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
index 12345a26bcf2..d39efbe724fb 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -182,7 +182,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -371,7 +373,7 @@ def t_fn(_sigma):
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 2c0be3b842cc..bb7dc21e6fdb 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -159,7 +159,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -325,7 +327,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -357,20 +359,8 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- # Hack to make sure that other schedulers which copy this function don't break
- # TODO: Add this logic to the other schedulers
- if hasattr(self.config, "sigma_min"):
- sigma_min = self.config.sigma_min
- else:
- sigma_min = None
-
- if hasattr(self.config, "sigma_max"):
- sigma_max = self.config.sigma_max
- else:
- sigma_max = None
-
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
index 7c0dd803d91b..41ef3a3f2732 100644
--- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -145,7 +145,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 53dc2ae15432..0875e1af3325 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -144,10 +144,7 @@ def __init__(
prediction_type: str = "epsilon",
interpolation_type: str = "linear",
use_karras_sigmas: Optional[bool] = False,
- sigma_min: Optional[float] = None,
- sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace",
- timestep_type: str = "discrete", # can be "discrete" or "continuous"
steps_offset: int = 0,
):
if trained_betas is not None:
@@ -156,7 +153,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -167,22 +166,13 @@ def __init__(
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
- timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
-
- sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32)
- timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas)
# setable values
self.num_inference_steps = None
-
- # TODO: Support the full EDM scalings for all prediction types and timestep types
- if timestep_type == "continuous" and prediction_type == "v_prediction":
- self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
- else:
- self.timesteps = timesteps
-
- self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
-
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps)
self.is_scale_input_called = False
self.use_karras_sigmas = use_karras_sigmas
@@ -280,20 +270,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
- sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
-
- # TODO: Support the full EDM scalings for all prediction types and timestep types
- if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
- self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
- else:
- self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+ self.sigmas = torch.from_numpy(sigmas).to(device=device)
- self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
self._step_index = None
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -318,20 +303,8 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- # Hack to make sure that other schedulers which copy this function don't break
- # TODO: Add this logic to the other schedulers
- if hasattr(self.config, "sigma_min"):
- sigma_min = self.config.sigma_min
- else:
- sigma_min = None
-
- if hasattr(self.config, "sigma_max"):
- sigma_max = self.config.sigma_max
- else:
- sigma_max = None
-
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
@@ -441,7 +414,7 @@ def step(
elif self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
elif self.config.prediction_type == "v_prediction":
- # denoised = model_output * c_out + input * c_skip
+ # * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
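
The restored comment above is terse; the `v_prediction` branch converts the model output back into a prediction of the clean sample. A minimal sketch of that conversion (function name is ours):

```py
import torch


def v_prediction_to_x0(model_output: torch.Tensor, sample: torch.Tensor, sigma: float) -> torch.Tensor:
    # denoised = model_output * c_out + sample * c_skip, with
    # c_out = -sigma / sqrt(sigma**2 + 1) and c_skip = 1 / (sigma**2 + 1).
    return model_output * (-sigma / (sigma**2 + 1) ** 0.5) + sample / (sigma**2 + 1)
```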
diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py
index 460299cf2ec1..a5827bbc8610 100644
--- a/src/diffusers/schedulers/scheduling_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_heun_discrete.py
@@ -131,7 +131,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine")
@@ -278,7 +280,7 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -303,20 +305,8 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- # Hack to make sure that other schedulers which copy this function don't break
- # TODO: Add this logic to the other schedulers
- if hasattr(self.config, "sigma_min"):
- sigma_min = self.config.sigma_min
- else:
- sigma_min = None
-
- if hasattr(self.config, "sigma_max"):
- sigma_max = self.config.sigma_max
- else:
- sigma_max = None
-
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
index aae5a15abca2..a0137b83fda1 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -127,7 +127,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -299,7 +301,7 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -324,20 +326,8 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- # Hack to make sure that other schedulers which copy this function don't break
- # TODO: Add this logic to the other schedulers
- if hasattr(self.config, "sigma_min"):
- sigma_min = self.config.sigma_min
- else:
- sigma_min = None
-
- if hasattr(self.config, "sigma_max"):
- sigma_max = self.config.sigma_max
- else:
- sigma_max = None
-
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
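
The `_sigma_to_t` change repeated across these scheduler diffs only drops the `1e-10` clamp; the interpolation itself is unchanged. A standalone sketch of that interpolation (an illustrative helper, not a public API), assuming `log_sigmas` is the ascending array of training log-sigmas:

```py
import numpy as np

def sigma_to_t(sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
    log_sigma = np.log(sigma)  # the clamped variant uses np.log(np.maximum(sigma, 1e-10))

    # distance of each query log-sigma to every training log-sigma
    dists = log_sigma - log_sigmas[:, np.newaxis]

    # index of the closest training sigma at or below each query (clipped so high_idx stays valid)
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1

    # linear interpolation between the two neighbouring timesteps
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.reshape(sigma.shape)
```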
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
index 3248520aa9a5..ddea57e8c167 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -126,7 +126,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -310,7 +312,7 @@ def _init_step_index(self, timestep):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -335,20 +337,8 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- # Hack to make sure that other schedulers which copy this function don't break
- # TODO: Add this logic to the other schedulers
- if hasattr(self.config, "sigma_min"):
- sigma_min = self.config.sigma_min
- else:
- sigma_min = None
-
- if hasattr(self.config, "sigma_max"):
- sigma_max = self.config.sigma_max
- else:
- sigma_max = None
-
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py
new file mode 100644
index 000000000000..462169b633de
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_karras_ve.py
@@ -0,0 +1,243 @@
+# Copyright 2023 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from ..utils.torch_utils import randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class KarrasVeOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Derivative of predicted original image sample (x_0).
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ derivative: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ A stochastic scheduler tailored to variance-expanding models.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+
+ <Tip>
+
+ For more details on the parameters, see [Appendix E](https://arxiv.org/abs/2206.00364). The grid search values used
+ to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5 of the paper.
+
+ </Tip>
+
+ Args:
+ sigma_min (`float`, defaults to 0.02):
+ The minimum noise magnitude.
+ sigma_max (`float`, defaults to 100):
+ The maximum noise magnitude.
+ s_noise (`float`, defaults to 1.007):
+ The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
+ 1.011].
+ s_churn (`float`, defaults to 80):
+ The parameter controlling the overall amount of stochasticity. A reasonable range is [0, 100].
+ s_min (`float`, defaults to 0.05):
+ The start value of the sigma range to add noise (enable stochasticity). A reasonable range is [0, 10].
+ s_max (`float`, defaults to 50):
+ The end value of the sigma range to add noise. A reasonable range is [0.2, 80].
+ """
+
+ order = 2
+
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.02,
+ sigma_max: float = 100,
+ s_noise: float = 1.007,
+ s_churn: float = 80,
+ s_min: float = 0.05,
+ s_max: float = 50,
+ ):
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = sigma_max
+
+ # setable values
+ self.num_inference_steps: int = None
+ self.timesteps: np.IntTensor = None
+ self.schedule: torch.FloatTensor = None # sigma(t_i)
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+ timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ schedule = [
+ (
+ self.config.sigma_max**2
+ * (self.config.sigma_min**2 / self.config.sigma_max**2) ** (i / (num_inference_steps - 1))
+ )
+ for i in self.timesteps
+ ]
+ self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
+
+ def add_noise_to_input(
+ self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
+ ) -> Tuple[torch.FloatTensor, float]:
+ """
+ Explicit Langevin-like "churn" step of adding noise to the sample according to a `gamma_i ≥ 0` to reach a
+ higher noise level `sigma_hat = sigma_i + gamma_i*sigma_i`.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ sigma (`float`):
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ """
+ if self.config.s_min <= sigma <= self.config.s_max:
+ gamma = min(self.config.s_churn / self.num_inference_steps, 2**0.5 - 1)
+ else:
+ gamma = 0
+
+ # sample eps ~ N(0, S_noise^2 * I)
+ eps = self.config.s_noise * randn_tensor(sample.shape, generator=generator).to(sample.device)
+ sigma_hat = sigma + gamma * sigma
+ sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
+
+ return sample_hat, sigma_hat
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ sigma_hat (`float`):
+ sigma_prev (`float`):
+ sample_hat (`torch.FloatTensor`):
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] is returned,
+ otherwise a tuple is returned where the first element is the sample tensor.
+
+ """
+
+ pred_original_sample = sample_hat + sigma_hat * model_output
+ derivative = (sample_hat - pred_original_sample) / sigma_hat
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(
+ prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
+ )
+
+ def step_correct(
+ self,
+ model_output: torch.FloatTensor,
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: torch.FloatTensor,
+ sample_prev: torch.FloatTensor,
+ derivative: torch.FloatTensor,
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Corrects the predicted sample based on the `model_output` of the network.
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor`): TODO
+ sample_prev (`torch.FloatTensor`): TODO
+ derivative (`torch.FloatTensor`): TODO
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
+
+ Returns:
+ prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
+
+ """
+ pred_original_sample = sample_prev + sigma_prev * model_output
+ derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(
+ prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
+ )
+
+ def add_noise(self, original_samples, noise, timesteps):
+ raise NotImplementedError()
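
Putting the new scheduler together, a minimal sampling loop might look as follows. This is a sketch modelled on the library's Karras VE pipeline; `model` (a UNet-style callable whose output has a `.sample` attribute), the image `shape`, and the input/output scaling convention are assumptions, not part of this file:

```py
import torch
from diffusers.utils.torch_utils import randn_tensor

def sample_karras_ve(model, scheduler, shape, num_inference_steps=50, generator=None):
    # start from pure noise at sigma_max
    sample = randn_tensor(shape, generator=generator) * scheduler.init_noise_sigma
    scheduler.set_timesteps(num_inference_steps)

    for t in scheduler.timesteps:
        sigma = scheduler.schedule[t]
        sigma_prev = scheduler.schedule[t - 1] if t > 0 else 0

        # 1. stochastic "churn": raise the noise level from sigma to sigma_hat
        sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma, generator=generator)

        # 2. Euler step from sigma_hat down to sigma_prev
        model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
        step_output = scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)

        # 3. second-order correction, skipped on the final step (sigma_prev == 0)
        if sigma_prev != 0:
            model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
            step_output = scheduler.step_correct(
                model_output,
                sigma_hat,
                sigma_prev,
                sample_hat,
                step_output.prev_sample,
                step_output.derivative,
            )
        sample = step_output.prev_sample

    return (sample / 2 + 0.5).clamp(0, 1)
```

The correction step averages the Euler derivative with a second derivative evaluated at `sigma_prev` (the `0.5 * derivative + 0.5 * derivative_corr` line in `step_correct`), which is what makes this an `order = 2` scheduler.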
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 90e81c9b3c2c..9bee37d59ee1 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -146,7 +146,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -303,7 +305,7 @@ def _init_step_index(self, timestep):
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 4e5ef375a672..94bd6e51605e 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -132,7 +132,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index 9a7f15622234..733bd0a159fd 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -134,7 +134,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
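
The `scaled_linear` hunks above (and the identical ones in the other scheduler files) only re-wrap the same expression for line length. Standing alone, with beta defaults commonly used for Stable Diffusion plugged in as example values:

```py
import torch

beta_start, beta_end, num_train_timesteps = 0.00085, 0.012, 1000  # typical Stable Diffusion values

# "scaled_linear": linear in sqrt(beta) space, then squared
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # the quantity most schedulers derive from the betas
print(betas[0].item(), betas[-1].item())  # ~0.00085 ... ~0.012
```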
diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py
new file mode 100644
index 000000000000..b14bc867befa
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_sde_vp.py
@@ -0,0 +1,111 @@
+# Copyright 2023 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+import math
+from typing import Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils.torch_utils import randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `ScoreSdeVpScheduler` is a variance preserving stochastic differential equation (SDE) scheduler.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 2000):
+ The number of diffusion steps to train the model.
+ beta_min (`int`, defaults to 0.1):
+ beta_max (`int`, defaults to 20):
+ sampling_eps (`int`, defaults to 1e-3):
+ The end value of sampling where timesteps decrease progressively from 1 to epsilon.
+ """
+
+ order = 1
+
+ @register_to_config
+ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
+ self.sigmas = None
+ self.discrete_sigmas = None
+ self.timesteps = None
+
+ def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
+ """
+ Sets the continuous timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ """
+ self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)
+
+ def step_pred(self, score, x, t, generator=None):
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ score ():
+ x ():
+ t ():
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ """
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # TODO(Patrick) better comments + non-PyTorch
+ # postprocess model score
+ log_mean_coeff = (
+ -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
+ )
+ std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
+ std = std.flatten()
+ while len(std.shape) < len(score.shape):
+ std = std.unsqueeze(-1)
+ score = -score / std
+
+ # compute
+ dt = -1.0 / len(self.timesteps)
+
+ beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
+ beta_t = beta_t.flatten()
+ while len(beta_t.shape) < len(x.shape):
+ beta_t = beta_t.unsqueeze(-1)
+ drift = -0.5 * beta_t * x
+
+ diffusion = torch.sqrt(beta_t)
+ drift = drift - diffusion**2 * score
+ x_mean = x + drift * dt
+
+ # add noise
+ noise = randn_tensor(x.shape, layout=x.layout, generator=generator, device=x.device, dtype=x.dtype)
+ x = x_mean + diffusion * math.sqrt(-dt) * noise
+
+ return x, x_mean
+
+ def __len__(self):
+ return self.config.num_train_timesteps
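
A minimal usage sketch for the new `ScoreSdeVpScheduler`. The score network here is a placeholder lambda standing in for a trained model, and its call signature is an assumption for illustration only:

```py
import torch
from diffusers.schedulers.scheduling_sde_vp import ScoreSdeVpScheduler

scheduler = ScoreSdeVpScheduler(num_train_timesteps=2000)
scheduler.set_timesteps(num_inference_steps=1000)

score_model = lambda x, t: -x  # placeholder for a trained score network returning a tensor shaped like x

x = torch.randn(4, 3, 32, 32)  # start from Gaussian noise
for t in scheduler.timesteps:
    score = score_model(x, t)
    x, x_mean = scheduler.step_pred(score, x, t)
# x_mean is the noise-free estimate usually returned at the end of sampling
```

Each `step_pred` call is one Euler-Maruyama step of the reverse-time VP SDE: the raw score is rescaled by the marginal standard deviation, the drift `-0.5 * beta_t * x - beta_t * score` is applied over `dt`, and `sqrt(beta_t) * sqrt(-dt)` Gaussian noise is added.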
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index d778f37ec059..741b03b6d3a2 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -162,7 +162,9 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -305,7 +307,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
- log_sigma = np.log(np.maximum(sigma, 1e-10))
+ log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -337,20 +339,8 @@ def _sigma_to_alpha_sigma_t(self, sigma):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
- # Hack to make sure that other schedulers which copy this function don't break
- # TODO: Add this logic to the other schedulers
- if hasattr(self.config, "sigma_min"):
- sigma_min = self.config.sigma_min
- else:
- sigma_min = None
-
- if hasattr(self.config, "sigma_max"):
- sigma_max = self.config.sigma_max
- else:
- sigma_max = None
-
- sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
- sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+ sigma_min: float = in_sigmas[-1].item()
+ sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index c1385d584724..b4d6bdab33eb 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -89,7 +89,6 @@
from .outputs import BaseOutput
from .peft_utils import (
check_peft_version,
- delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
recurse_remove_peft_layers,
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 608a751fb8d6..3023cb476fe0 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -17,15 +17,12 @@
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
from packaging import version
-from ..dependency_versions_check import dep_version_check
-from .import_utils import ENV_VARS_TRUE_VALUES, is_peft_available, is_transformers_available
+from .import_utils import is_peft_available, is_transformers_available
default_cache_path = HUGGINGFACE_HUB_CACHE
-MIN_PEFT_VERSION = "0.6.0"
-MIN_TRANSFORMERS_VERSION = "4.34.0"
-_CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES
+MIN_PEFT_VERSION = "0.5.0"
CONFIG_NAME = "config.json"
@@ -43,15 +40,12 @@
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
# available.
-# For PEFT it has to be greater than or equal to 0.6.0 and for transformers it has to be greater than or equal to 4.34.0.
+# For PEFT it has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
-) >= version.parse(MIN_PEFT_VERSION)
+) > version.parse(MIN_PEFT_VERSION)
_required_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
-) >= version.parse(MIN_TRANSFORMERS_VERSION)
+) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
-
-if USE_PEFT_BACKEND and _CHECK_PEFT:
- dep_version_check("peft")
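
The gate above reduces to a pair of strict version comparisons (older thresholds: `peft > 0.5.0` and `transformers > 4.33`, with no runtime `dep_version_check`). A standalone sketch of the same logic; `meets_version` is an illustrative helper, not a library function:

```py
import importlib.metadata
from packaging import version

def meets_version(package: str, minimum: str) -> bool:
    try:
        installed = version.parse(importlib.metadata.version(package)).base_version
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) > version.parse(minimum)  # strict ">", matching the diff

USE_PEFT_BACKEND = meets_version("peft", "0.5.0") and meets_version("transformers", "4.33")
print(USE_PEFT_BACKEND)
```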
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index c19b15f2f483..8e95dde52caf 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -32,21 +32,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
-
class AutoencoderTiny(metaclass=DummyObject):
_backends = ["torch"]
@@ -62,21 +47,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class ConsistencyDecoderVAE(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
-
class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -92,21 +62,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class Kandinsky3UNet(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
-
class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]
@@ -122,21 +77,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class MotionAdapter(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
-
class MultiAdapter(metaclass=DummyObject):
_backends = ["torch"]
@@ -272,36 +212,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class UNetMotionModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
-
-class UNetSpatioTemporalConditionModel(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
-
class VQModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -915,21 +825,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class LCMScheduler(metaclass=DummyObject):
- _backends = ["torch"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch"])
-
-
class PNDMScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index b039cdc72ab6..d831cc49b495 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -32,21 +32,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffPipeline(metaclass=DummyObject):
- _backends = ["torch", "transformers"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch", "transformers"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
-
class AudioLDM2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -242,36 +227,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class Kandinsky3Img2ImgPipeline(metaclass=DummyObject):
- _backends = ["torch", "transformers"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch", "transformers"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
-
-class Kandinsky3Pipeline(metaclass=DummyObject):
- _backends = ["torch", "transformers"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch", "transformers"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
-
class KandinskyCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -527,36 +482,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject):
- _backends = ["torch", "transformers"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch", "transformers"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
-
-class LatentConsistencyModelPipeline(metaclass=DummyObject):
- _backends = ["torch", "transformers"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch", "transformers"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
-
class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -602,21 +527,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PixArtAlphaPipeline(metaclass=DummyObject):
- _backends = ["torch", "transformers"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch", "transformers"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
-
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1172,21 +1082,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class StableVideoDiffusionPipeline(metaclass=DummyObject):
- _backends = ["torch", "transformers"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch", "transformers"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
-
class TextToVideoSDPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1217,21 +1112,6 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class TextToVideoZeroSDXLPipeline(metaclass=DummyObject):
- _backends = ["torch", "transformers"]
-
- def __init__(self, *args, **kwargs):
- requires_backends(self, ["torch", "transformers"])
-
- @classmethod
- def from_config(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
- @classmethod
- def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["torch", "transformers"])
-
-
class UnCLIPImageVariationPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py
index d668cb40c631..5b0952f0b514 100644
--- a/src/diffusers/utils/dynamic_modules_utils.py
+++ b/src/diffusers/utils/dynamic_modules_utils.py
@@ -87,9 +87,9 @@ def get_relative_imports(module_file):
content = f.read()
# Imports of the form `import .xxx`
- relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+ relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from .xxx import yyy`
- relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+ relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
# Unique-ify
return list(set(relative_imports))
@@ -131,9 +131,9 @@ def check_imports(filename):
content = f.read()
# Imports of the form `import xxx`
- imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+ imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
# Imports of the form `from xxx import yyy`
- imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+ imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
# Only keep the top-level module
imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
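
Dropping the `r` prefix does not change what these patterns match; it only means the `\s`/`\S` escapes go through string escape processing first (and can trigger deprecation warnings on newer Python versions). A quick check of both pattern groups on a small synthetic snippet:

```py
import re

content = "from .models import UNet2DConditionModel\nimport .utils\nimport torch\nfrom transformers import CLIPTextModel\n"

# relative imports (`import .xxx` / `from .xxx import yyy`)
relative = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
relative += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
print(sorted(set(relative)))  # ['models', 'utils']

# top-level imports (`import xxx` / `from xxx import yyy`), with relative ones filtered out
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
print(sorted({imp.split(".")[0] for imp in imports if not imp.startswith(".")}))  # ['torch', 'transformers']
```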
diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py
index 45aece18b8fd..f7744f9d63eb 100644
--- a/src/diffusers/utils/export_utils.py
+++ b/src/diffusers/utils/export_utils.py
@@ -3,7 +3,7 @@
import struct
import tempfile
from contextlib import contextmanager
-from typing import List, Union
+from typing import List
import numpy as np
import PIL.Image
@@ -115,9 +115,7 @@ def export_to_obj(mesh, output_obj_path: str = None):
f.writelines("\n".join(combined_data))
-def export_to_video(
- video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
-) -> str:
+def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
if is_opencv_available():
import cv2
else:
@@ -125,12 +123,9 @@ def export_to_video(
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
- if isinstance(video_frames[0], PIL.Image.Image):
- video_frames = [np.array(frame) for frame in video_frames]
-
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
h, w, c = video_frames[0].shape
- video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
+ video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
for i in range(len(video_frames)):
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
video_writer.write(img)
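
With the older signature restored above, `export_to_video` only accepts a list of uint8 RGB NumPy frames and always writes at 8 fps, so PIL frames have to be converted by the caller. A small sketch using synthetic placeholder frames (OpenCV must be installed):

```py
import numpy as np
import PIL.Image
from diffusers.utils import export_to_video

# dummy 16-frame clip; real pipelines return PIL images or numpy arrays
frames = [PIL.Image.new("RGB", (256, 256), (8 * i, 0, 0)) for i in range(16)]

video_path = export_to_video([np.array(frame) for frame in frames])  # fps is fixed at 8 here
print(video_path)
```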
diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py
index 6050f314c008..4ccc57cd69d5 100644
--- a/src/diffusers/utils/logging.py
+++ b/src/diffusers/utils/logging.py
@@ -28,7 +28,7 @@
WARN, # NOQA
WARNING, # NOQA
)
-from typing import Dict, Optional
+from typing import Optional
from tqdm import auto as tqdm_lib
@@ -49,7 +49,7 @@
_tqdm_active = True
-def _get_default_logging_level() -> int:
+def _get_default_logging_level():
"""
If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to `_default_log_level`
@@ -104,7 +104,7 @@ def _reset_library_root_logger() -> None:
_default_handler = None
-def get_log_levels_dict() -> Dict[str, int]:
+def get_log_levels_dict():
return log_levels
@@ -161,22 +161,22 @@ def set_verbosity(verbosity: int) -> None:
_get_library_root_logger().setLevel(verbosity)
-def set_verbosity_info() -> None:
+def set_verbosity_info():
"""Set the verbosity to the `INFO` level."""
return set_verbosity(INFO)
-def set_verbosity_warning() -> None:
+def set_verbosity_warning():
"""Set the verbosity to the `WARNING` level."""
return set_verbosity(WARNING)
-def set_verbosity_debug() -> None:
+def set_verbosity_debug():
"""Set the verbosity to the `DEBUG` level."""
return set_verbosity(DEBUG)
-def set_verbosity_error() -> None:
+def set_verbosity_error():
"""Set the verbosity to the `ERROR` level."""
return set_verbosity(ERROR)
@@ -263,7 +263,7 @@ def reset_format() -> None:
handler.setFormatter(None)
-def warning_advice(self, *args, **kwargs) -> None:
+def warning_advice(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed
@@ -327,13 +327,13 @@ def is_progress_bar_enabled() -> bool:
return bool(_tqdm_active)
-def enable_progress_bar() -> None:
+def enable_progress_bar():
"""Enable tqdm progress bar."""
global _tqdm_active
_tqdm_active = True
-def disable_progress_bar() -> None:
+def disable_progress_bar():
"""Disable tqdm progress bar."""
global _tqdm_active
_tqdm_active = False
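
These are purely annotation changes; the helpers behave the same either way. Typical usage, for context:

```py
from diffusers.utils import logging

logging.set_verbosity_info()            # or set_verbosity_warning / set_verbosity_debug / set_verbosity_error
logger = logging.get_logger("diffusers")
logger.info("verbosity is now INFO")

logging.disable_progress_bar()          # toggles the global tqdm flag
logging.enable_progress_bar()
```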
diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py
index 01a297361955..802c699eb9cc 100644
--- a/src/diffusers/utils/outputs.py
+++ b/src/diffusers/utils/outputs.py
@@ -24,7 +24,7 @@
from .import_utils import is_torch_available
-def is_tensor(x) -> bool:
+def is_tensor(x):
"""
Tests if `x` is a `torch.Tensor` or `np.ndarray`.
"""
@@ -51,22 +51,7 @@ class BaseOutput(OrderedDict):
"""
- def __init_subclass__(cls) -> None:
- """Register subclasses as pytree nodes.
-
- This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
- `static_graph=True` with modules that output `ModelOutput` subclasses.
- """
- if is_torch_available():
- import torch.utils._pytree
-
- torch.utils._pytree._register_pytree_node(
- cls,
- torch.utils._pytree._dict_flatten,
- lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
- )
-
- def __post_init__(self) -> None:
+ def __post_init__(self):
class_fields = fields(self)
# Safety and consistency checks
@@ -97,14 +82,14 @@ def pop(self, *args, **kwargs):
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
- def __getitem__(self, k: Any) -> Any:
+ def __getitem__(self, k):
if isinstance(k, str):
inner_dict = dict(self.items())
return inner_dict[k]
else:
return self.to_tuple()[k]
- def __setattr__(self, name: Any, value: Any) -> None:
+ def __setattr__(self, name, value):
if name in self.keys() and value is not None:
# Don't call self.__setitem__ to avoid recursion errors
super().__setitem__(name, value)
@@ -123,7 +108,7 @@ def __reduce__(self):
args = tuple(getattr(self, field.name) for field in fields(self))
return callable, args, *remaining
- def to_tuple(self) -> Tuple[Any, ...]:
+ def to_tuple(self) -> Tuple[Any]:
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index c77efc28f62a..efc977518b14 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -16,84 +16,61 @@
"""
import collections
import importlib
-from typing import Optional
from packaging import version
from .import_utils import is_peft_available, is_torch_available
-if is_torch_available():
- import torch
-
-
def recurse_remove_peft_layers(model):
+ if is_torch_available():
+ import torch
+
r"""
Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
"""
- from peft.tuners.tuners_utils import BaseTunerLayer
+ from peft.tuners.lora import LoraLayer
+
+ for name, module in model.named_children():
+ if len(list(module.children())) > 0:
+ ## compound module, go inside it
+ recurse_remove_peft_layers(module)
+
+ module_replaced = False
+
+ if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
+ new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
+ module.weight.device
+ )
+ new_module.weight = module.weight
+ if module.bias is not None:
+ new_module.bias = module.bias
+
+ module_replaced = True
+ elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
+ new_module = torch.nn.Conv2d(
+ module.in_channels,
+ module.out_channels,
+ module.kernel_size,
+ module.stride,
+ module.padding,
+ module.dilation,
+ module.groups,
+ ).to(module.weight.device)
+
+ new_module.weight = module.weight
+ if module.bias is not None:
+ new_module.bias = module.bias
+
+ module_replaced = True
+
+ if module_replaced:
+ setattr(model, name, new_module)
+ del module
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
- has_base_layer_pattern = False
- for module in model.modules():
- if isinstance(module, BaseTunerLayer):
- has_base_layer_pattern = hasattr(module, "base_layer")
- break
-
- if has_base_layer_pattern:
- from peft.utils import _get_submodules
-
- key_list = [key for key, _ in model.named_modules() if "lora" not in key]
- for key in key_list:
- try:
- parent, target, target_name = _get_submodules(model, key)
- except AttributeError:
- continue
- if hasattr(target, "base_layer"):
- setattr(parent, target_name, target.get_base_layer())
- else:
- # This is for backwards compatibility with PEFT <= 0.6.2.
- # TODO can be removed once that PEFT version is no longer supported.
- from peft.tuners.lora import LoraLayer
-
- for name, module in model.named_children():
- if len(list(module.children())) > 0:
- ## compound module, go inside it
- recurse_remove_peft_layers(module)
-
- module_replaced = False
-
- if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
- new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
- module.weight.device
- )
- new_module.weight = module.weight
- if module.bias is not None:
- new_module.bias = module.bias
-
- module_replaced = True
- elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
- new_module = torch.nn.Conv2d(
- module.in_channels,
- module.out_channels,
- module.kernel_size,
- module.stride,
- module.padding,
- module.dilation,
- module.groups,
- ).to(module.weight.device)
-
- new_module.weight = module.weight
- if module.bias is not None:
- new_module.bias = module.bias
-
- module_replaced = True
-
- if module_replaced:
- setattr(model, name, new_module)
- del module
-
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
return model
@@ -114,28 +91,21 @@ def scale_lora_layers(model, weight):
module.scale_layer(weight)
-def unscale_lora_layers(model, weight: Optional[float] = None):
+def unscale_lora_layers(model):
"""
Removes the previously passed weight given to the LoRA layers of the model.
Args:
model (`torch.nn.Module`):
The model to scale.
- weight (`float`, *optional*):
- The weight to be given to the LoRA layers. If no scale is passed the scale of the lora layer will be
- re-initialized to the correct value. If 0.0 is passed, we will re-initialize the scale with the correct
- value.
+ weight (`float`):
+ The weight to be given to the LoRA layers.
"""
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
- if weight is not None and weight != 0:
- module.unscale_layer(weight)
- elif weight is not None and weight == 0:
- for adapter_name in module.active_adapters:
- # if weight == 0 unscale should re-set the scale to the original value.
- module.set_scale(adapter_name, 1.0)
+ module.unscale_layer()
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
@@ -151,7 +121,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
- if network_alpha_dict is not None and len(network_alpha_dict) > 0:
+ if network_alpha_dict is not None:
if len(set(network_alpha_dict.values())) > 1:
# get the alpha occuring the most number of times
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
@@ -202,28 +172,6 @@ def set_adapter_layers(model, enabled=True):
module.disable_adapters = not enabled
-def delete_adapter_layers(model, adapter_name):
- from peft.tuners.tuners_utils import BaseTunerLayer
-
- for module in model.modules():
- if isinstance(module, BaseTunerLayer):
- if hasattr(module, "delete_adapter"):
- module.delete_adapter(adapter_name)
- else:
- raise ValueError(
- "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
- )
-
- # For transformers integration - we need to pop the adapter from the config
- if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
- model.peft_config.pop(adapter_name, None)
- # In case all adapters are deleted, we need to delete the config
- # and make sure to set the flag to False
- if len(model.peft_config) == 0:
- del model.peft_config
- model._hf_peft_config_loaded = None
-
-
def set_weights_and_activate_adapters(model, adapter_names, weights):
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -236,7 +184,7 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
- module.set_scale(adapter_name, weight)
+ module.scale_layer(weight)
# set multiple active adapters
for module in model.modules():
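
With this older API, `unscale_lora_layers` takes no `weight` argument and simply calls `unscale_layer()` on every `BaseTunerLayer`, and per-adapter weights go through `scale_layer`. A hedged sketch of how the scaling helpers are meant to bracket a forward pass; `pipe` and `prompt` are assumed to be an already-loaded LoRA pipeline and a prompt string:

```py
from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers

lora_scale = 0.7
scale_lora_layers(pipe.unet, lora_scale)        # multiply the LoRA contribution for this call
image = pipe(prompt, num_inference_steps=20).images[0]
unscale_lora_layers(pipe.unet)                  # undo the scaling afterwards (no weight argument in this version)
```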
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 00bc75f41be3..7955ccb01d85 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -82,14 +82,14 @@ def randn_tensor(
return latents
-def is_compiled_module(module) -> bool:
+def is_compiled_module(module):
"""Check whether the module was compiled with torch.compile()"""
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
return False
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
-def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor:
+def fourier_filter(x_in, threshold, scale):
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
This version of the method comes from here:
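
For context, the `fourier_filter` whose annotations change here is the FreeU-style frequency filter: scale the low-frequency band around the (shifted) spectrum centre and transform back. A minimal sketch of that idea, not a verbatim copy of the library function:

```py
import torch
import torch.fft as fft

def fourier_filter_sketch(x: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
    # x: (batch, channels, height, width)
    x_freq = fft.fftshift(fft.fftn(x, dim=(-2, -1)), dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    # attenuate (or boost) a (2 * threshold)^2 low-frequency window around the centre
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale

    x_filtered = fft.ifftn(fft.ifftshift(x_freq * mask, dim=(-2, -1)), dim=(-2, -1)).real
    return x_filtered.to(x.dtype)

filtered = fourier_filter_sketch(torch.randn(1, 4, 64, 64), threshold=1, scale=0.9)
```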
diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py
index 19505a1d906d..047cdddfa95a 100644
--- a/tests/lora/test_lora_layers_old_backend.py
+++ b/tests/lora/test_lora_layers_old_backend.py
@@ -41,7 +41,7 @@
UNet2DConditionModel,
UNet3DConditionModel,
)
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
@@ -51,7 +51,6 @@
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
-from diffusers.models.lora import PatchedLoraProjection, text_encoder_attn_modules
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
deprecate_after_peft_backend,
@@ -246,7 +245,6 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
@@ -758,7 +756,6 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -868,8 +865,6 @@ def get_dummy_components(self):
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
- "image_encoder": None,
- "feature_extractor": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py
index 48ae5d197273..198ff53340c8 100644
--- a/tests/lora/test_lora_layers_peft.py
+++ b/tests/lora/test_lora_layers_peft.py
@@ -22,25 +22,25 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import (
AutoencoderKL,
- AutoPipelineForImage2Image,
ControlNetModel,
DDIMScheduler,
DiffusionPipeline,
EulerDiscreteScheduler,
- LCMScheduler,
StableDiffusionPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
+from diffusers.models.attention_processor import (
+ LoRAAttnProcessor,
+ LoRAAttnProcessor2_0,
+)
from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
from diffusers.utils.testing_utils import (
floats_tensor,
@@ -106,12 +106,10 @@ class PeftLoraLoaderMixinTests:
unet_kwargs = None
vae_kwargs = None
- def get_dummy_components(self, scheduler_cls=None):
- scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
-
+ def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(**self.unet_kwargs)
- scheduler = scheduler_cls(**self.scheduler_kwargs)
+ scheduler = self.scheduler_cls(**self.scheduler_kwargs)
torch.manual_seed(0)
vae = AutoencoderKL(**self.vae_kwargs)
text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
@@ -140,8 +138,6 @@ def get_dummy_components(self, scheduler_cls=None):
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
- "image_encoder": None,
- "feature_extractor": None,
}
else:
pipeline_components = {
@@ -152,7 +148,6 @@ def get_dummy_components(self, scheduler_cls=None):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
@@ -204,896 +199,673 @@ def test_simple_inference(self):
"""
Tests a simple inference and makes sure it works as expected
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
+ components, _, _, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs()
- output_no_lora = pipe(**inputs).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ _, _, inputs = self.get_dummy_inputs()
+ output_no_lora = pipe(**inputs).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
def test_simple_inference_with_text_lora(self):
"""
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
+
def test_simple_inference_with_text_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ components, _, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- output_lora_scale = pipe(
- **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
- ).images
- self.assertTrue(
- not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- output_lora_0_scale = pipe(
- **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
- ).images
- self.assertTrue(
- np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
- "Lora + 0 scale should lead to same result as no LoRA",
- )
+ output_lora_scale = pipe(
+ **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+ ).images
+ self.assertTrue(
+ not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+
+ output_lora_0_scale = pipe(
+ **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+ ).images
+ self.assertTrue(
+ np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+ "Lora + 0 scale should lead to same result as no LoRA",
+ )
def test_simple_inference_with_text_lora_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.fuse_lora()
+ # Fusing should still keep the LoRA layers
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- pipe.fuse_lora()
- # Fusing should still keep the LoRA layers
+ if self.has_two_text_encoders:
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- if self.has_two_text_encoders:
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertFalse(
- np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ output_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+ )
def test_simple_inference_with_text_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- pipe.text_encoder.add_adapter(text_lora_config)
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.unload_lora_weights()
+ # unloading should remove the LoRA layers
+ self.assertFalse(
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
+ )
- pipe.unload_lora_weights()
- # unloading should remove the LoRA layers
+ if self.has_two_text_encoders:
self.assertFalse(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2"
)
- if self.has_two_text_encoders:
- self.assertFalse(
- self.check_if_lora_correctly_set(pipe.text_encoder_2),
- "Lora not correctly unloaded in text encoder 2",
- )
-
- ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
- "Fused lora should change the output",
- )
+ output_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), "Unloading the LoRA weights should restore the original output"
+ )
def test_simple_inference_with_text_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ with tempfile.TemporaryDirectory() as tmpdirname:
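+ # get_peft_model_state_dict returns only the LoRA adapter parameters, which are then written out via save_lora_weights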
+ text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+ if self.has_two_text_encoders:
+ text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
- with tempfile.TemporaryDirectory() as tmpdirname:
- text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
- if self.has_two_text_encoders:
- text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ text_encoder_2_lora_layers=text_encoder_2_state_dict,
+ safe_serialization=False,
+ )
+ else:
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ safe_serialization=False,
+ )
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- text_encoder_2_lora_layers=text_encoder_2_state_dict,
- safe_serialization=False,
- )
- else:
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- safe_serialization=False,
- )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
+ if self.has_two_text_encoders:
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- if self.has_two_text_encoders:
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
def test_simple_inference_save_pretrained(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
+ pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+ pipe_from_pretrained.to(self.torch_device)
- pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
- pipe_from_pretrained.to(self.torch_device)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
+ "Lora not correctly set in text encoder",
+ )
+ if self.has_two_text_encoders:
self.assertTrue(
- self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
- "Lora not correctly set in text encoder",
+ self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
+ "Lora not correctly set in text encoder 2",
)
- if self.has_two_text_encoders:
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
- "Lora not correctly set in text encoder 2",
- )
-
- images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images
+ images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
def test_simple_inference_with_text_unet_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+ unet_state_dict = get_peft_model_state_dict(pipe.unet)
if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ text_encoder_2_lora_layers=text_encoder_2_state_dict,
+ unet_lora_layers=unet_state_dict,
+ safe_serialization=False,
+ )
+ else:
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ unet_lora_layers=unet_state_dict,
+ safe_serialization=False,
)
- images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
- unet_state_dict = get_peft_model_state_dict(pipe.unet)
- if self.has_two_text_encoders:
- text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
-
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- text_encoder_2_lora_layers=text_encoder_2_state_dict,
- unet_lora_layers=unet_state_dict,
- safe_serialization=False,
- )
- else:
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- unet_lora_layers=unet_state_dict,
- safe_serialization=False,
- )
-
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
-
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
-
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
- if self.has_two_text_encoders:
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+ if self.has_two_text_encoders:
self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
+
def test_simple_inference_with_text_unet_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + Unet + scale argument
and makes sure it works as expected
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
-
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- output_lora_scale = pipe(
- **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
- ).images
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- output_lora_0_scale = pipe(
- **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
- ).images
- self.assertTrue(
- np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
- "Lora + 0 scale should lead to same result as no LoRA",
- )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- self.assertTrue(
- pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
- "The scaling parameter has not been correctly restored!",
- )
+ output_lora_scale = pipe(
+ **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+ ).images
+ self.assertTrue(
+ not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+
+ output_lora_0_scale = pipe(
+ **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+ ).images
+ self.assertTrue(
+ np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+ "Lora + 0 scale should lead to same result as no LoRA",
+ )
+
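+ # after running with cross_attention_kwargs={"scale": ...}, the LoRA scaling on the text encoder should be restored to its default of 1.0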
+ self.assertTrue(
+ pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
+ "The scaling parameter has not been correctly restored!",
+ )
def test_simple_inference_with_text_lora_unet_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.fuse_lora()
+ # Fusing should still keep the LoRA layers
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet")
- pipe.fuse_lora()
- # Fusing should still keep the LoRA layers
+ if self.has_two_text_encoders:
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet")
- if self.has_two_text_encoders:
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertFalse(
- np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ output_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+ )
def test_simple_inference_with_text_unet_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
-
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.unload_lora_weights()
+ # unloading should remove the LoRA layers
+ self.assertFalse(
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
+ )
+ self.assertFalse(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet")
- pipe.unload_lora_weights()
- # unloading should remove the LoRA layers
+ if self.has_two_text_encoders:
self.assertFalse(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2"
)
- self.assertFalse(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet")
-
- if self.has_two_text_encoders:
- self.assertFalse(
- self.check_if_lora_correctly_set(pipe.text_encoder_2),
- "Lora not correctly unloaded in text encoder 2",
- )
- ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
- "Fused lora should change the output",
- )
+ output_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), "Unloading the LoRA weights should restore the original output"
+ )
def test_simple_inference_with_text_unet_lora_unfused(self):
"""
Tests a simple inference with lora attached to text encoder and unet, fuses the lora weights into the base model,
then unfuses them and makes sure it works as expected
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.fuse_lora()
- pipe.fuse_lora()
+ output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.unfuse_lora()
- pipe.unfuse_lora()
+ output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ # unfusing should still keep the LoRA layers
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers")
- output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- # unloading should remove the LoRA layers
+ if self.has_two_text_encoders:
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers")
-
- if self.has_two_text_encoders:
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
- )
- # Fuse and unfuse should lead to the same results
- self.assertTrue(
- np.allclose(output_fused_lora, output_unfused_lora, atol=1e-3, rtol=1e-3),
- "Fused lora should change the output",
- )
+ # Fuse and unfuse should lead to the same results
+ self.assertTrue(
+ np.allclose(output_fused_lora, output_unfused_lora, atol=1e-3, rtol=1e-3),
+ "Fused lora should change the output",
+ )
def test_simple_inference_with_text_unet_multi_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
-
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
-
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- pipe.set_adapters("adapter-1")
-
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- pipe.set_adapters(["adapter-1", "adapter-2"])
-
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
-
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
-
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
-
- pipe.disable_lora()
-
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
-
- def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
- """
- Tests a simple inference with lora attached to text encoder and unet, attaches
- multiple adapters and set/delete them
- """
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
-
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
-
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- pipe.set_adapters("adapter-1")
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- pipe.set_adapters(["adapter-1", "adapter-2"])
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-2")
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
-
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
-
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
-
- pipe.delete_adapters("adapter-1")
- output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- self.assertTrue(
- np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
-
- pipe.delete_adapters("adapter-2")
- output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- self.assertTrue(
- np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
-
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
-
- pipe.set_adapters(["adapter-1", "adapter-2"])
- pipe.delete_adapters(["adapter-1", "adapter-2"])
-
- output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- self.assertTrue(
- np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
-
- def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
- """
- Tests a simple inference with lora attached to text encoder and unet, attaches
- multiple adapters and set them
- """
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
-
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- pipe.set_adapters("adapter-1")
-
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.set_adapters("adapter-1")
- pipe.set_adapters(["adapter-1", "adapter-2"])
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.set_adapters("adapter-2")
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ pipe.set_adapters(["adapter-1", "adapter-2"])
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
+ # Different adapters (and their combination) should lead to different results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
- output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Weighted adapter and mixed adapter should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- pipe.disable_lora()
+ pipe.disable_lora()
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
def test_lora_fuse_nan(self):
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- # corrupt one LoRA weight with `inf` values
- with torch.no_grad():
- pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
- "inf"
- )
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
+ "inf"
+ )
- # with `safe_fusing=True` we should see an Error
- with self.assertRaises(ValueError):
- pipe.fuse_lora(safe_fusing=True)
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(safe_fusing=True)
- # without we should not see an error, but every image will be black
- pipe.fuse_lora(safe_fusing=False)
+ # without we should not see an error, but every image will be black
+ pipe.fuse_lora(safe_fusing=False)
- out = pipe("test", num_inference_steps=2, output_type="np").images
+ out = pipe("test", num_inference_steps=2, output_type="np").images
- self.assertTrue(np.isnan(out).all())
+ self.assertTrue(np.isnan(out).all())
def test_get_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- adapter_names = pipe.get_active_adapters()
- self.assertListEqual(adapter_names, ["adapter-1"])
+ adapter_names = pipe.get_active_adapters()
+ self.assertListEqual(adapter_names, ["adapter-1"])
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-2")
- adapter_names = pipe.get_active_adapters()
- self.assertListEqual(adapter_names, ["adapter-2"])
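+ # adding a new adapter makes it the active one, so only "adapter-2" is reported as active here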
+ adapter_names = pipe.get_active_adapters()
+ self.assertListEqual(adapter_names, ["adapter-2"])
- pipe.set_adapters(["adapter-1", "adapter-2"])
- self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
def test_get_list_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- adapter_names = pipe.get_list_adapters()
- self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]})
+ adapter_names = pipe.get_list_adapters()
+ self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]})
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-2")
- adapter_names = pipe.get_list_adapters()
- self.assertDictEqual(
- adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]}
- )
+ adapter_names = pipe.get_list_adapters()
+ self.assertDictEqual(
+ adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]}
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"])
- self.assertDictEqual(
- pipe.get_list_adapters(),
- {"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]},
- )
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ self.assertDictEqual(
+ pipe.get_list_adapters(), {"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]}
+ )
- pipe.unet.add_adapter(unet_lora_config, "adapter-3")
- self.assertDictEqual(
- pipe.get_list_adapters(),
- {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
- )
+ pipe.unet.add_adapter(unet_lora_config, "adapter-3")
+ self.assertDictEqual(
+ pipe.get_list_adapters(),
+ {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
+ )
@unittest.skip("This is failing for now - need to investigate")
def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
@@ -1101,35 +873,32 @@ def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
Tests a simple inference with lora attached to text encoder and unet, compiles the modules with torch.compile
and makes sure it works as expected
"""
- for scheduler_cls in [DDIMScheduler, LCMScheduler]:
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
-
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
- if self.has_two_text_encoders:
- pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
- # Just makes sure it works..
- _ = pipe(**inputs, generator=torch.manual_seed(0)).images
+ # Just makes sure it works.
+ _ = pipe(**inputs, generator=torch.manual_seed(0)).images
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
@@ -1304,6 +1073,7 @@ def test_integration_logits_multi_adapter(self):
expected_slice_scale = np.array([0.538, 0.539, 0.540, 0.540, 0.542, 0.539, 0.538, 0.541, 0.539])
predicted_slice = images[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
@@ -1336,7 +1106,7 @@ def test_integration_logits_multi_adapter(self):
output_type="np",
).images
predicted_slice = images[0, -3:, -3:, -1].flatten()
- expected_slice_scale = np.array([0.5888, 0.5897, 0.5946, 0.5888, 0.5935, 0.5946, 0.5857, 0.5891, 0.5909])
+ expected_slice_scale = np.array([0.5977, 0.5985, 0.6039, 0.5976, 0.6025, 0.6036, 0.5946, 0.5979, 0.5998])
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
# Lora disabled
@@ -1350,7 +1120,7 @@ def test_integration_logits_multi_adapter(self):
output_type="np",
).images
predicted_slice = images[0, -3:, -3:, -1].flatten()
- expected_slice_scale = np.array([0.5456, 0.5466, 0.5487, 0.5458, 0.5469, 0.5454, 0.5446, 0.5479, 0.5487])
+ expected_slice_scale = np.array([0.54625, 0.5473, 0.5495, 0.5465, 0.5476, 0.5461, 0.5452, 0.5485, 0.5493])
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
@@ -1731,97 +1501,6 @@ def test_sdxl_1_0_lora(self):
self.assertTrue(np.allclose(images, expected, atol=1e-4))
release_memory(pipe)
- def test_sdxl_lcm_lora(self):
- pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload()
-
- generator = torch.Generator().manual_seed(0)
-
- lora_model_id = "latent-consistency/lcm-lora-sdxl"
-
- pipe.load_lora_weights(lora_model_id)
-
- image = pipe(
- "masterpiece, best quality, mountain", generator=generator, num_inference_steps=4, guidance_scale=0.5
- ).images[0]
-
- expected_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_lora/sdxl_lcm_lora.png"
- )
-
- image_np = pipe.image_processor.pil_to_numpy(image)
- expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
-
- self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
-
- pipe.unload_lora_weights()
-
- release_memory(pipe)
-
- def test_sdv1_5_lcm_lora(self):
- pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
- pipe.to("cuda")
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
-
- generator = torch.Generator().manual_seed(0)
-
- lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
- pipe.load_lora_weights(lora_model_id)
-
- image = pipe(
- "masterpiece, best quality, mountain", generator=generator, num_inference_steps=4, guidance_scale=0.5
- ).images[0]
-
- expected_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_lora/sdv15_lcm_lora.png"
- )
-
- image_np = pipe.image_processor.pil_to_numpy(image)
- expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
-
- self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
-
- pipe.unload_lora_weights()
-
- release_memory(pipe)
-
- def test_sdv1_5_lcm_lora_img2img(self):
- pipe = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
- pipe.to("cuda")
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
-
- init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.png"
- )
-
- generator = torch.Generator().manual_seed(0)
-
- lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
- pipe.load_lora_weights(lora_model_id)
-
- image = pipe(
- "snowy mountain",
- generator=generator,
- image=init_image,
- strength=0.5,
- num_inference_steps=4,
- guidance_scale=0.5,
- ).images[0]
-
- expected_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_lora/sdv15_lcm_lora_img2img.png"
- )
-
- image_np = pipe.image_processor.pil_to_numpy(image)
- expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
-
- self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
-
- pipe.unload_lora_weights()
-
- release_memory(pipe)
-
def test_sdxl_1_0_lora_fusion(self):
generator = torch.Generator().manual_seed(0)
@@ -2021,28 +1700,6 @@ def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self):
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe)
- def test_sd_load_civitai_empty_network_alpha(self):
- """
- This test simply checks that loading a LoRA with an empty network alpha works fine
- See: https://github.com/huggingface/diffusers/issues/5606
- """
- pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
- pipeline.enable_sequential_cpu_offload()
- civitai_path = hf_hub_download("ybelkada/test-ahi-civitai", "ahi_lora_weights.safetensors")
- pipeline.load_lora_weights(civitai_path, adapter_name="ahri")
-
- images = pipeline(
- "ahri, masterpiece, league of legends",
- output_type="np",
- generator=torch.manual_seed(156),
- num_inference_steps=5,
- ).images
- images = images[0, -3:, -3:, -1].flatten()
- expected = np.array([0.0, 0.0, 0.0, 0.002557, 0.020954, 0.001792, 0.006581, 0.00591, 0.002995])
-
- self.assertTrue(np.allclose(images, expected, atol=1e-3))
- release_memory(pipeline)
-
def test_canny_lora(self):
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 961147839461..80c97978723c 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -196,15 +196,11 @@ def test_forward_with_norm_groups(self):
class ModelTesterMixin:
main_input_name = None # overwrite in model specific tester class
base_precision = 1e-3
- forward_requires_fresh_args = False
def test_from_save_pretrained(self, expected_max_diff=5e-5):
- if self.forward_requires_fresh_args:
- model = self.model_class(**self.init_dict)
- else:
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor()
model.to(torch_device)
@@ -218,18 +214,11 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5):
new_model.to(torch_device)
with torch.no_grad():
- if self.forward_requires_fresh_args:
- image = model(**self.inputs_dict(0))
- else:
- image = model(**inputs_dict)
-
+ image = model(**inputs_dict)
if isinstance(image, dict):
image = image.to_tuple()[0]
- if self.forward_requires_fresh_args:
- new_image = new_model(**self.inputs_dict(0))
- else:
- new_image = new_model(**inputs_dict)
+ new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
@@ -286,11 +275,8 @@ def test_getattr_is_correct(self):
)
def test_set_xformers_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False)
- if self.forward_requires_fresh_args:
- model = self.model_class(**self.init_dict)
- else:
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
@@ -300,42 +286,20 @@ def test_set_xformers_attn_processor_for_determinism(self):
model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
- if self.forward_requires_fresh_args:
- output = model(**self.inputs_dict(0))[0]
- else:
- output = model(**inputs_dict)[0]
+ output = model(**inputs_dict)[0]
model.enable_xformers_memory_efficient_attention()
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
- if self.forward_requires_fresh_args:
- output_2 = model(**self.inputs_dict(0))[0]
- else:
- output_2 = model(**inputs_dict)[0]
-
- model.set_attn_processor(XFormersAttnProcessor())
- assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
- with torch.no_grad():
- if self.forward_requires_fresh_args:
- output_3 = model(**self.inputs_dict(0))[0]
- else:
- output_3 = model(**inputs_dict)[0]
-
- torch.use_deterministic_algorithms(True)
+ output_2 = model(**inputs_dict)[0]
assert torch.allclose(output, output_2, atol=self.base_precision)
- assert torch.allclose(output, output_3, atol=self.base_precision)
- assert torch.allclose(output_2, output_3, atol=self.base_precision)
@require_torch_gpu
def test_set_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False)
- if self.forward_requires_fresh_args:
- model = self.model_class(**self.init_dict)
- else:
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
-
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
@@ -344,34 +308,32 @@ def test_set_attn_processor_for_determinism(self):
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
- if self.forward_requires_fresh_args:
- output_1 = model(**self.inputs_dict(0))[0]
- else:
- output_1 = model(**inputs_dict)[0]
+ output_1 = model(**inputs_dict)[0]
model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
- if self.forward_requires_fresh_args:
- output_2 = model(**self.inputs_dict(0))[0]
- else:
- output_2 = model(**inputs_dict)[0]
+ output_2 = model(**inputs_dict)[0]
+
+ model.enable_xformers_memory_efficient_attention()
+ assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
+ with torch.no_grad():
+ model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor2_0())
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
- if self.forward_requires_fresh_args:
- output_4 = model(**self.inputs_dict(0))[0]
- else:
- output_4 = model(**inputs_dict)[0]
+ output_4 = model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor())
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
- if self.forward_requires_fresh_args:
- output_5 = model(**self.inputs_dict(0))[0]
- else:
- output_5 = model(**inputs_dict)[0]
+ output_5 = model(**inputs_dict)[0]
+
+ model.set_attn_processor(XFormersAttnProcessor())
+ assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
+ with torch.no_grad():
+ output_6 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True)
@@ -379,14 +341,12 @@ def test_set_attn_processor_for_determinism(self):
assert torch.allclose(output_2, output_1, atol=self.base_precision)
assert torch.allclose(output_2, output_4, atol=self.base_precision)
assert torch.allclose(output_2, output_5, atol=self.base_precision)
+ assert torch.allclose(output_2, output_6, atol=self.base_precision)
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
- if self.forward_requires_fresh_args:
- model = self.model_class(**self.init_dict)
- else:
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor()
@@ -409,17 +369,11 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
new_model.to(torch_device)
with torch.no_grad():
- if self.forward_requires_fresh_args:
- image = model(**self.inputs_dict(0))
- else:
- image = model(**inputs_dict)
+ image = model(**inputs_dict)
if isinstance(image, dict):
image = image.to_tuple()[0]
- if self.forward_requires_fresh_args:
- new_image = new_model(**self.inputs_dict(0))
- else:
- new_image = new_model(**inputs_dict)
+ new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
@@ -453,26 +407,17 @@ def test_from_save_pretrained_dtype(self):
assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5):
- if self.forward_requires_fresh_args:
- model = self.model_class(**self.init_dict)
- else:
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
- if self.forward_requires_fresh_args:
- first = model(**self.inputs_dict(0))
- else:
- first = model(**inputs_dict)
+ first = model(**inputs_dict)
if isinstance(first, dict):
first = first.to_tuple()[0]
- if self.forward_requires_fresh_args:
- second = model(**self.inputs_dict(0))
- else:
- second = model(**inputs_dict)
+ second = model(**inputs_dict)
if isinstance(second, dict):
second = second.to_tuple()[0]
@@ -605,22 +550,15 @@ def recursive_check(tuple_object, dict_object):
),
)
- if self.forward_requires_fresh_args:
- model = self.model_class(**self.init_dict)
- else:
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
- if self.forward_requires_fresh_args:
- outputs_dict = model(**self.inputs_dict(0))
- outputs_tuple = model(**self.inputs_dict(0), return_dict=False)
- else:
- outputs_dict = model(**inputs_dict)
- outputs_tuple = model(**inputs_dict, return_dict=False)
+ outputs_dict = model(**inputs_dict)
+ outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
diff --git a/tests/models/test_models_prior.py b/tests/models/test_models_prior.py
index 9b02de463ecd..4c47a44ef52a 100644
--- a/tests/models/test_models_prior.py
+++ b/tests/models/test_models_prior.py
@@ -162,8 +162,8 @@ def tearDown(self):
@parameterized.expand(
[
# fmt: off
- [13, [-0.5861, 0.1283, -0.0931, 0.0882, 0.4476, 0.1329, -0.0498, 0.0640]],
- [37, [-0.4913, 0.0110, -0.0483, 0.0541, 0.4954, -0.0170, 0.0354, 0.1651]],
+ [13, [-0.5861, 0.1283, -0.0931, 0.0882, 0.4476, 0.1329, -0.0498, 0.0640]],
+ [37, [-0.4913, 0.0110, -0.0483, 0.0541, 0.4954, -0.0170, 0.0354, 0.1651]],
# fmt: on
]
)
diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py
index 06bf2685560d..d8b412aa12d9 100644
--- a/tests/models/test_models_unet_2d_condition.py
+++ b/tests/models/test_models_unet_2d_condition.py
@@ -24,8 +24,7 @@
from pytest import mark
from diffusers import UNet2DConditionModel
-from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
-from diffusers.models.embeddings import ImageProjection
+from diffusers.models.attention_processor import CustomDiffusionAttnProcessor
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
@@ -46,57 +45,6 @@
enable_full_determinism()
-def create_ip_adapter_state_dict(model):
- # "ip_adapter" (cross-attention weights)
- ip_cross_attn_state_dict = {}
- key_id = 1
-
- for name in model.attn_processors.keys():
- cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
- if name.startswith("mid_block"):
- hidden_size = model.config.block_out_channels[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- hidden_size = list(reversed(model.config.block_out_channels))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- hidden_size = model.config.block_out_channels[block_id]
- if cross_attention_dim is not None:
- sd = IPAdapterAttnProcessor(
- hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
- ).state_dict()
- ip_cross_attn_state_dict.update(
- {
- f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
- f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
- }
- )
-
- key_id += 2
-
- # "image_proj" (ImageProjection layer weights)
- cross_attention_dim = model.config["cross_attention_dim"]
- image_projection = ImageProjection(
- cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, num_image_text_embeds=4
- )
-
- ip_image_projection_state_dict = {}
- sd = image_projection.state_dict()
- ip_image_projection_state_dict.update(
- {
- "proj.weight": sd["image_embeds.weight"],
- "proj.bias": sd["image_embeds.bias"],
- "norm.weight": sd["norm.weight"],
- "norm.bias": sd["norm.bias"],
- }
- )
-
- del sd
- ip_state_dict = {}
- ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
- return ip_state_dict
-
-
def create_custom_diffusion_layers(model, mock_weights: bool = True):
train_kv = True
train_q_out = True
@@ -658,72 +606,6 @@ def test_pickle(self):
assert (sample - sample_copy).abs().max() < 1e-4
- def test_asymmetrical_unet(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- # Add asymmetry to configs
- init_dict["transformer_layers_per_block"] = [[3, 2], 1]
- init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict)
- model.to(torch_device)
-
- output = model(**inputs_dict).sample
- expected_shape = inputs_dict["sample"].shape
-
- # Check if input and output shapes are the same
- self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
- def test_ip_adapter(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- init_dict["attention_head_dim"] = (8, 16)
-
- model = self.model_class(**init_dict)
- model.to(torch_device)
-
- # forward pass without ip-adapter
- with torch.no_grad():
- sample1 = model(**inputs_dict).sample
-
- # update inputs_dict for ip-adapter
- batch_size = inputs_dict["encoder_hidden_states"].shape[0]
- image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
- inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
-
- # make ip_adapter_1 and ip_adapter_2
- ip_adapter_1 = create_ip_adapter_state_dict(model)
-
- image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()}
- cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()}
- ip_adapter_2 = {}
- ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
-
- # forward pass ip_adapter_1
- model._load_ip_adapter_weights(ip_adapter_1)
- assert model.config.encoder_hid_dim_type == "ip_image_proj"
- assert model.encoder_hid_proj is not None
- assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
- "IPAdapterAttnProcessor",
- "IPAdapterAttnProcessor2_0",
- )
- with torch.no_grad():
- sample2 = model(**inputs_dict).sample
-
- # forward pass with ip_adapter_2
- model._load_ip_adapter_weights(ip_adapter_2)
- with torch.no_grad():
- sample3 = model(**inputs_dict).sample
-
- # forward pass with ip_adapter_1 again
- model._load_ip_adapter_weights(ip_adapter_1)
- with torch.no_grad():
- sample4 = model(**inputs_dict).sample
-
- assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
- assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
- assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
-
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py
index aa755e387b61..fe2bcdb0af35 100644
--- a/tests/models/test_models_vae.py
+++ b/tests/models/test_models_vae.py
@@ -16,20 +16,11 @@
import gc
import unittest
-import numpy as np
import torch
from parameterized import parameterized
-from diffusers import (
- AsymmetricAutoencoderKL,
- AutoencoderKL,
- AutoencoderKLTemporalDecoder,
- AutoencoderTiny,
- ConsistencyDecoderVAE,
- StableDiffusionPipeline,
-)
+from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.loading_utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
@@ -39,7 +30,6 @@
torch_all_close,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
@@ -47,82 +37,6 @@
enable_full_determinism()
-def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
- block_out_channels = block_out_channels or [32, 64]
- norm_num_groups = norm_num_groups or 32
- init_dict = {
- "block_out_channels": block_out_channels,
- "in_channels": 3,
- "out_channels": 3,
- "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
- "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
- "latent_channels": 4,
- "norm_num_groups": norm_num_groups,
- }
- return init_dict
-
-
-def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
- block_out_channels = block_out_channels or [32, 64]
- norm_num_groups = norm_num_groups or 32
- init_dict = {
- "in_channels": 3,
- "out_channels": 3,
- "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
- "down_block_out_channels": block_out_channels,
- "layers_per_down_block": 1,
- "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
- "up_block_out_channels": block_out_channels,
- "layers_per_up_block": 1,
- "act_fn": "silu",
- "latent_channels": 4,
- "norm_num_groups": norm_num_groups,
- "sample_size": 32,
- "scaling_factor": 0.18215,
- }
- return init_dict
-
-
-def get_autoencoder_tiny_config(block_out_channels=None):
- block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32]
- init_dict = {
- "in_channels": 3,
- "out_channels": 3,
- "encoder_block_out_channels": block_out_channels,
- "decoder_block_out_channels": block_out_channels,
- "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels],
- "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)],
- }
- return init_dict
-
-
-def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
- block_out_channels = block_out_channels or [32, 64]
- norm_num_groups = norm_num_groups or 32
- return {
- "encoder_block_out_channels": block_out_channels,
- "encoder_in_channels": 3,
- "encoder_out_channels": 4,
- "encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
- "decoder_add_attention": False,
- "decoder_block_out_channels": block_out_channels,
- "decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels),
- "decoder_downsample_padding": 1,
- "decoder_in_channels": 7,
- "decoder_layers_per_block": 1,
- "decoder_norm_eps": 1e-05,
- "decoder_norm_num_groups": norm_num_groups,
- "encoder_norm_num_groups": norm_num_groups,
- "decoder_num_train_timesteps": 1024,
- "decoder_out_channels": 6,
- "decoder_resnet_time_scale_shift": "scale_shift",
- "decoder_time_embedding_type": "learned",
- "decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels),
- "scaling_factor": 1,
- "latent_channels": 4,
- }
-
-
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
@@ -147,7 +61,14 @@ def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
- init_dict = get_autoencoder_kl_config()
+ init_dict = {
+ "block_out_channels": [32, 64],
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ "latent_channels": 4,
+ }
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -249,31 +170,11 @@ def test_output_pretrained(self):
)
elif torch_device == "cpu":
expected_output_slice = torch.tensor(
- [
- -0.1352,
- 0.0878,
- 0.0419,
- -0.0818,
- -0.1069,
- 0.0688,
- -0.1458,
- -0.4446,
- -0.0026,
- ]
+ [-0.1352, 0.0878, 0.0419, -0.0818, -0.1069, 0.0688, -0.1458, -0.4446, -0.0026]
)
else:
expected_output_slice = torch.tensor(
- [
- -0.2421,
- 0.4642,
- 0.2507,
- -0.0438,
- 0.0682,
- 0.3160,
- -0.2018,
- -0.0727,
- 0.2485,
- ]
+ [-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@@ -304,7 +205,21 @@ def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
- init_dict = get_asym_autoencoder_kl_config()
+ init_dict = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ "down_block_out_channels": [32, 64],
+ "layers_per_down_block": 1,
+ "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ "up_block_out_channels": [32, 64],
+ "layers_per_up_block": 1,
+ "act_fn": "silu",
+ "latent_channels": 4,
+ "norm_num_groups": 32,
+ "sample_size": 32,
+ "scaling_factor": 0.18215,
+ }
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -338,139 +253,21 @@ def input_shape(self):
def output_shape(self):
return (3, 32, 32)
- def prepare_init_args_and_inputs_for_common(self):
- init_dict = get_autoencoder_tiny_config()
- inputs_dict = self.dummy_input
- return init_dict, inputs_dict
-
- def test_outputs_equivalence(self):
- pass
-
-
-class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
- model_class = ConsistencyDecoderVAE
- main_input_name = "sample"
- base_precision = 1e-2
- forward_requires_fresh_args = True
-
- def inputs_dict(self, seed=None):
- generator = torch.Generator("cpu")
- if seed is not None:
- generator.manual_seed(0)
- image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
-
- return {"sample": image, "generator": generator}
-
- @property
- def input_shape(self):
- return (3, 32, 32)
-
- @property
- def output_shape(self):
- return (3, 32, 32)
-
- @property
- def init_dict(self):
- return get_consistency_vae_config()
-
- def prepare_init_args_and_inputs_for_common(self):
- return self.init_dict, self.inputs_dict()
-
- @unittest.skip
- def test_training(self):
- ...
-
- @unittest.skip
- def test_ema_training(self):
- ...
-
-
-class AutoncoderKLTemporalDecoderFastTests(ModelTesterMixin, unittest.TestCase):
- model_class = AutoencoderKLTemporalDecoder
- main_input_name = "sample"
- base_precision = 1e-2
-
- @property
- def dummy_input(self):
- batch_size = 3
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
- num_frames = 3
-
- return {"sample": image, "num_frames": num_frames}
-
- @property
- def input_shape(self):
- return (3, 32, 32)
-
- @property
- def output_shape(self):
- return (3, 32, 32)
-
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
- "block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 3,
- "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
- "latent_channels": 4,
- "layers_per_block": 2,
+ "encoder_block_out_channels": (32, 32),
+ "decoder_block_out_channels": (32, 32),
+ "num_encoder_blocks": (1, 2),
+ "num_decoder_blocks": (2, 1),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
- def test_forward_signature(self):
- pass
-
- def test_training(self):
+ def test_outputs_equivalence(self):
pass
- @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
- def test_gradient_checkpointing(self):
- # enable deterministic behavior for gradient checkpointing
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
- model.to(torch_device)
-
- assert not model.is_gradient_checkpointing and model.training
-
- out = model(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model.zero_grad()
-
- labels = torch.randn_like(out)
- loss = (out - labels).mean()
- loss.backward()
-
- # re-instantiate the model now enabling gradient checkpointing
- model_2 = self.model_class(**init_dict)
- # clone model
- model_2.load_state_dict(model.state_dict())
- model_2.to(torch_device)
- model_2.enable_gradient_checkpointing()
-
- assert model_2.is_gradient_checkpointing and model_2.training
-
- out_2 = model_2(**inputs_dict).sample
- # run the backwards pass on the model. For backwards pass, for simplicity purpose,
- # we won't calculate the loss and rather backprop on out.sum()
- model_2.zero_grad()
- loss_2 = (out_2 - labels).mean()
- loss_2.backward()
-
- # compare the output and parameters gradients
- self.assertTrue((loss - loss_2).abs() < 1e-5)
- named_params = dict(model.named_parameters())
- named_params_2 = dict(model_2.named_parameters())
- for name, param in named_params.items():
- if "post_quant_conv" in name:
- continue
-
- self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
@@ -587,16 +384,8 @@ def get_generator(self, seed=0):
@parameterized.expand(
[
# fmt: off
- [
- 33,
- [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
- [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824],
- ],
- [
- 47,
- [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
- [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131],
- ],
+ [33, [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824], [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824]],
+ [47, [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089], [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131]],
# fmt: on
]
)
@@ -642,16 +431,8 @@ def test_stable_diffusion_fp16(self, seed, expected_slice):
@parameterized.expand(
[
# fmt: off
- [
- 33,
- [-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814],
- [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824],
- ],
- [
- 47,
- [-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085],
- [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131],
- ],
+ [33, [-0.1609, 0.9866, -0.0487, -0.0777, -0.2716, 0.8368, -0.2055, -0.0814], [-0.2395, 0.0098, 0.0102, -0.0709, -0.2840, -0.0274, -0.0718, -0.1824]],
+ [47, [-0.2377, 0.1147, 0.1333, -0.4841, -0.2506, -0.0805, -0.0491, -0.4085], [0.0350, 0.0847, 0.0467, 0.0344, -0.0842, -0.0547, -0.0633, -0.1131]],
# fmt: on
]
)
@@ -717,10 +498,7 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
@parameterized.expand([(13,), (16,), (27,)])
@require_torch_gpu
- @unittest.skipIf(
- not is_xformers_available(),
- reason="xformers is not required when using PyTorch 2.0.",
- )
+ @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -738,10 +516,7 @@ def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
- @unittest.skipIf(
- not is_xformers_available(),
- reason="xformers is not required when using PyTorch 2.0.",
- )
+ @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -839,16 +614,8 @@ def get_generator(self, seed=0):
@parameterized.expand(
[
# fmt: off
- [
- 33,
- [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078],
- [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
- ],
- [
- 47,
- [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
- [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
- ],
+ [33, [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078], [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]],
+ [47, [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529], [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]],
# fmt: on
]
)
@@ -870,16 +637,8 @@ def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
@parameterized.expand(
[
# fmt: off
- [
- 33,
- [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097],
- [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078],
- ],
- [
- 47,
- [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
- [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531],
- ],
+ [33, [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097], [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078]],
+ [47, [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531], [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531]],
# fmt: on
]
)
@@ -900,7 +659,7 @@ def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
@parameterized.expand(
[
# fmt: off
- [13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
+ [13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
[37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
# fmt: on
]
@@ -922,10 +681,7 @@ def test_stable_diffusion_decode(self, seed, expected_slice):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
- @unittest.skipIf(
- not is_xformers_available(),
- reason="xformers is not required when using PyTorch 2.0.",
- )
+ @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
@@ -965,106 +721,3 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
-
-
-@slow
-class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- @torch.no_grad()
- def test_encode_decode(self):
- vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
- vae.to(torch_device)
-
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/img2img/sketch-mountains-input.jpg"
- ).resize((256, 256))
- image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
- None, :, :, :
- ].cuda()
-
- latent = vae.encode(image).latent_dist.mean
-
- sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
-
- actual_output = sample[0, :2, :2, :2].flatten().cpu()
- expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
-
- assert torch_all_close(actual_output, expected_output, atol=5e-3)
-
- def test_sd(self):
- vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
- pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
- pipe.to(torch_device)
-
- out = pipe(
- "horse",
- num_inference_steps=2,
- output_type="pt",
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
-
- actual_output = out[:2, :2, :2].flatten().cpu()
- expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
-
- assert torch_all_close(actual_output, expected_output, atol=5e-3)
-
- def test_encode_decode_f16(self):
- vae = ConsistencyDecoderVAE.from_pretrained(
- "openai/consistency-decoder", torch_dtype=torch.float16
- ) # TODO - update
- vae.to(torch_device)
-
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/img2img/sketch-mountains-input.jpg"
- ).resize((256, 256))
- image = (
- torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
- .half()
- .cuda()
- )
-
- latent = vae.encode(image).latent_dist.mean
-
- sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
-
- actual_output = sample[0, :2, :2, :2].flatten().cpu()
- expected_output = torch.tensor(
- [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471],
- dtype=torch.float16,
- )
-
- assert torch_all_close(actual_output, expected_output, atol=5e-3)
-
- def test_sd_f16(self):
- vae = ConsistencyDecoderVAE.from_pretrained(
- "openai/consistency-decoder", torch_dtype=torch.float16
- ) # TODO - update
- pipe = StableDiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
- vae=vae,
- safety_checker=None,
- )
- pipe.to(torch_device)
-
- out = pipe(
- "horse",
- num_inference_steps=2,
- output_type="pt",
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
-
- actual_output = out[:2, :2, :2].flatten().cpu()
- expected_output = torch.tensor(
- [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035],
- dtype=torch.float16,
- )
-
- assert torch_all_close(actual_output, expected_output, atol=5e-3)
diff --git a/tests/others/test_check_copies.py b/tests/others/test_check_copies.py
index b611fd7d19d7..3fdf7dfe8d1a 100644
--- a/tests/others/test_check_copies.py
+++ b/tests/others/test_check_copies.py
@@ -19,6 +19,8 @@
import tempfile
import unittest
+import black
+
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))
@@ -63,7 +65,8 @@ def check_copy_consistency(self, comment, class_name, class_code, overwrite_resu
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
if overwrite_result is not None:
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
- code = check_copies.run_ruff(code)
+ mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
+ code = black.format_str(code, mode=mode)
fname = os.path.join(self.diffusers_dir, "new_code.py")
with open(fname, "w", newline="\n") as f:
f.write(code)
diff --git a/tests/others/test_outputs.py b/tests/others/test_outputs.py
index cf709d93f709..492e71f0ba31 100644
--- a/tests/others/test_outputs.py
+++ b/tests/others/test_outputs.py
@@ -7,7 +7,6 @@
import PIL.Image
from diffusers.utils.outputs import BaseOutput
-from diffusers.utils.testing_utils import require_torch
@dataclass
@@ -70,24 +69,3 @@ def test_outputs_serialization(self):
assert dir(outputs_orig) == dir(outputs_copy)
assert dict(outputs_orig) == dict(outputs_copy)
assert vars(outputs_orig) == vars(outputs_copy)
-
- @require_torch
- def test_torch_pytree(self):
- # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
- # this is important for DistributedDataParallel gradient synchronization with static_graph=True
- import torch
- import torch.utils._pytree
-
- data = np.random.rand(1, 3, 4, 4)
- x = CustomOutput(images=data)
- self.assertFalse(torch.utils._pytree._is_leaf(x))
-
- expected_flat_outs = [data]
- expected_tree_spec = torch.utils._pytree.TreeSpec(CustomOutput, ["images"], [torch.utils._pytree.LeafSpec()])
-
- actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
- self.assertEqual(expected_flat_outs, actual_flat_outs)
- self.assertEqual(expected_tree_spec, actual_tree_spec)
-
- unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
- self.assertEqual(x, unflattened_x)
diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py
index b4a2847bb84d..da5eb34fe92f 100644
--- a/tests/pipelines/altdiffusion/test_alt_diffusion.py
+++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py
@@ -27,12 +27,7 @@
)
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -47,7 +42,6 @@ class AltDiffusionPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
@@ -117,7 +111,6 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
index 3fd1a90172ca..57001f7bea52 100644
--- a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
+++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py
@@ -141,7 +141,6 @@ def test_stable_diffusion_img2img_default_case(self):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
- image_encoder=None,
)
alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=True)
alt_pipe = alt_pipe.to(device)
@@ -206,7 +205,6 @@ def test_stable_diffusion_img2img_fp16(self):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
- image_encoder=None,
)
alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
alt_pipe = alt_pipe.to(torch_device)
diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py
index ce8693343043..64baeea910b8 100644
--- a/tests/pipelines/controlnet/test_controlnet.py
+++ b/tests/pipelines/controlnet/test_controlnet.py
@@ -27,7 +27,6 @@
ControlNetModel,
DDIMScheduler,
EulerDiscreteScheduler,
- LCMScheduler,
StableDiffusionControlNetPipeline,
UNet2DConditionModel,
)
@@ -117,7 +116,7 @@ class ControlNetPipelineFastTests(
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- def get_dummy_components(self, time_cond_proj_dim=None):
+ def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
@@ -129,7 +128,6 @@ def get_dummy_components(self, time_cond_proj_dim=None):
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
norm_num_groups=1,
- time_cond_proj_dim=time_cond_proj_dim,
)
torch.manual_seed(0)
controlnet = ControlNetModel(
@@ -183,7 +181,6 @@ def get_dummy_components(self, time_cond_proj_dim=None):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -224,52 +221,6 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
- def test_controlnet_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionControlNetPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array(
- [0.52700454, 0.3930534, 0.25509018, 0.7132304, 0.53696585, 0.46568912, 0.7095368, 0.7059624, 0.4744786]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_controlnet_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionControlNetPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array(
- [0.52700454, 0.3930534, 0.25509018, 0.7132304, 0.53696585, 0.46568912, 0.7095368, 0.7059624, 0.4744786]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
class StableDiffusionMultiControlNetPipelineFastTests(
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
@@ -366,7 +317,6 @@ def init_weights(m):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -544,7 +494,6 @@ def init_weights(m):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py
index 5a7f70eb488a..3113836f5d0a 100644
--- a/tests/pipelines/controlnet/test_controlnet_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_img2img.py
@@ -72,7 +72,7 @@ class ControlNetImg2ImgPipelineFastTests(
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
+ block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
@@ -80,17 +80,15 @@ def get_dummy_components(self):
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
- norm_num_groups=1,
)
torch.manual_seed(0)
controlnet = ControlNetModel(
- block_out_channels=(4, 8),
+ block_out_channels=(32, 64),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
- norm_num_groups=1,
)
torch.manual_seed(0)
scheduler = DDIMScheduler(
@@ -102,13 +100,12 @@ def get_dummy_components(self):
)
torch.manual_seed(0)
vae = AutoencoderKL(
- block_out_channels=[4, 8],
+ block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
- norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
@@ -189,7 +186,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
+ block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
@@ -197,7 +194,6 @@ def get_dummy_components(self):
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
- norm_num_groups=1,
)
torch.manual_seed(0)
@@ -207,25 +203,23 @@ def init_weights(m):
m.bias.data.fill_(1.0)
controlnet1 = ControlNetModel(
- block_out_channels=(4, 8),
+ block_out_channels=(32, 64),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
- norm_num_groups=1,
)
controlnet1.controlnet_down_blocks.apply(init_weights)
torch.manual_seed(0)
controlnet2 = ControlNetModel(
- block_out_channels=(4, 8),
+ block_out_channels=(32, 64),
layers_per_block=2,
in_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
cross_attention_dim=32,
conditioning_embedding_out_channels=(16, 32),
- norm_num_groups=1,
)
controlnet2.controlnet_down_blocks.apply(init_weights)
@@ -239,13 +233,12 @@ def init_weights(m):
)
torch.manual_seed(0)
vae = AutoencoderKL(
- block_out_channels=[4, 8],
+ block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
- norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py
index 7c3371c197d4..a9140f3d5a31 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -132,7 +132,6 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -249,7 +248,6 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -344,7 +342,6 @@ def init_weights(m):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index ba129e763c22..4fff88434bc3 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -24,11 +24,9 @@
AutoencoderKL,
ControlNetModel,
EulerDiscreteScheduler,
- LCMScheduler,
StableDiffusionXLControlNetPipeline,
UNet2DConditionModel,
)
-from diffusers.models.unet_2d_blocks import UNetMidBlock2D
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
@@ -44,7 +42,6 @@
PipelineKarrasSchedulerTesterMixin,
PipelineLatentTesterMixin,
PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
)
@@ -52,11 +49,7 @@
class StableDiffusionXLControlNetPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
- unittest.TestCase,
+ PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
@@ -64,7 +57,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- def get_dummy_components(self, time_cond_proj_dim=None):
+ def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
@@ -82,7 +75,6 @@ def get_dummy_components(self, time_cond_proj_dim=None):
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
- time_cond_proj_dim=time_cond_proj_dim,
)
torch.manual_seed(0)
controlnet = ControlNetModel(
@@ -147,8 +139,6 @@ def get_dummy_components(self, time_cond_proj_dim=None):
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
- "feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -189,9 +179,6 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
@require_torch_gpu
def test_stable_diffusion_xl_offloads(self):
pipes = []
@@ -335,29 +322,9 @@ def test_controlnet_sdxl_guess(self):
# make sure that it's equal
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4
- def test_controlnet_sdxl_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLControlNetPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.7799, 0.614, 0.6162, 0.7082, 0.6662, 0.5833, 0.4148, 0.5182, 0.4866])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
class StableDiffusionXLMultiControlNetPipelineFastTests(
- PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+ PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
@@ -474,8 +441,6 @@ def init_weights(m):
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
- "feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -505,7 +470,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
- "output_type": "np",
+ "output_type": "numpy",
"image": images,
}
@@ -557,12 +522,9 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
- def test_save_load_optional_components(self):
- return self._test_save_load_optional_components()
-
class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
- PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+ PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLControlNetPipeline
params = TEXT_TO_IMAGE_PARAMS
@@ -661,8 +623,6 @@ def init_weights(m):
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
- "feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -686,7 +646,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
- "output_type": "np",
+ "output_type": "numpy",
"image": images,
}
@@ -742,9 +702,6 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
def test_negative_conditions(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -818,162 +775,3 @@ def test_depth(self):
original_image = images[0, -3:, -3:, -1].flatten()
expected_image = np.array([0.4399, 0.5112, 0.5478, 0.4314, 0.472, 0.4823, 0.4647, 0.4957, 0.4853])
assert np.allclose(original_image, expected_image, atol=1e-04)
-
-
-class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNetPipelineFastTests):
- def test_controlnet_sdxl_guess(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- sd_pipe = self.pipeline_class(**components)
- sd_pipe = sd_pipe.to(device)
-
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["guess_mode"] = True
-
- output = sd_pipe(**inputs)
- image_slice = output.images[0, -3:, -3:, -1]
- expected_slice = np.array(
- [0.6831671, 0.5702532, 0.5459845, 0.6299793, 0.58563006, 0.6033695, 0.4493941, 0.46132287, 0.5035841]
- )
-
- # make sure that it's equal
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4
-
- def test_controlnet_sdxl_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLControlNetPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.6850, 0.5135, 0.5545, 0.7033, 0.6617, 0.5971, 0.4165, 0.5480, 0.5070])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_conditioning_channels(self):
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- mid_block_type="UNetMidBlock2D",
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=(1, 2),
- projection_class_embeddings_input_dim=80, # 6 * 8 + 32
- cross_attention_dim=64,
- time_cond_proj_dim=None,
- )
-
- controlnet = ControlNetModel.from_unet(unet, conditioning_channels=4)
- assert type(controlnet.mid_block) == UNetMidBlock2D
- assert controlnet.conditioning_channels == 4
-
- def get_dummy_components(self, time_cond_proj_dim=None):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- mid_block_type="UNetMidBlock2D",
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=(1, 2),
- projection_class_embeddings_input_dim=80, # 6 * 8 + 32
- cross_attention_dim=64,
- time_cond_proj_dim=time_cond_proj_dim,
- )
- torch.manual_seed(0)
- controlnet = ControlNetModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- in_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- conditioning_embedding_out_channels=(16, 32),
- mid_block_type="UNetMidBlock2D",
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=(1, 2),
- projection_class_embeddings_input_dim=80, # 6 * 8 + 32
- cross_attention_dim=64,
- )
- torch.manual_seed(0)
- scheduler = EulerDiscreteScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- steps_offset=1,
- beta_schedule="scaled_linear",
- timestep_spacing="leading",
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "controlnet": controlnet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py
index da037109ae8f..5dc5fe740317 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py
@@ -134,7 +134,7 @@ def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=2e-1)
+ super().test_float16_inference(expected_max_diff=1e-1)
def test_dict_tuple_outputs_equivalent(self):
super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky.py b/tests/pipelines/kandinsky2_2/test_kandinsky.py
index 64117b91fc03..65dbf0a708eb 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky.py
@@ -172,7 +172,6 @@ class KandinskyV22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"output_type",
"return_dict",
]
- callback_cfg_params = ["image_embds"]
test_xformers_attention = False
def get_dummy_inputs(self, device, seed=0):
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
index 2b7c1642b395..42c78bfc1af3 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
@@ -55,7 +55,6 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
"return_dict",
]
test_xformers_attention = True
- callback_cfg_params = ["image_embds"]
def get_dummy_components(self):
dummy = Dummies()
@@ -153,12 +152,6 @@ def test_save_load_local(self):
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-3)
- def test_callback_inputs(self):
- pass
-
- def test_callback_cfg(self):
- pass
-
class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22Img2ImgCombinedPipeline
@@ -179,7 +172,6 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
"return_dict",
]
test_xformers_attention = False
- callback_cfg_params = ["image_embds"]
def get_dummy_components(self):
dummy = Img2ImgDummies()
@@ -275,12 +267,6 @@ def test_save_load_optional_components(self):
def save_load_local(self):
super().test_save_load_local(expected_max_difference=5e-3)
- def test_callback_inputs(self):
- pass
-
- def test_callback_cfg(self):
- pass
-
class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22InpaintCombinedPipeline
@@ -398,9 +384,3 @@ def test_save_load_optional_components(self):
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
-
- def test_callback_inputs(self):
- pass
-
- def test_callback_cfg(self):
- pass
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
index 215f284d65db..9a5b596def58 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
@@ -192,7 +192,6 @@ class KandinskyV22Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCas
"return_dict",
]
test_xformers_attention = False
- callback_cfg_params = ["image_embeds"]
def get_dummy_components(self):
dummies = Dummies()
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
index 4225441ecee4..f40ec0d1f070 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
@@ -194,7 +194,6 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
"return_dict",
]
test_xformers_attention = False
- callback_cfg_params = ["image_embeds", "masked_image", "mask_image"]
def get_dummy_components(self):
dummies = Dummies()
@@ -253,40 +252,6 @@ def test_save_load_optional_components(self):
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
- # override default test because we need to zero out mask too in order to make sure final latent is all zero
- def test_callback_inputs(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- self.assertTrue(
- hasattr(pipe, "_callback_tensor_inputs"),
- f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
- )
-
- def callback_inputs_test(pipe, i, t, callback_kwargs):
- missing_callback_inputs = set()
- for v in pipe._callback_tensor_inputs:
- if v not in callback_kwargs:
- missing_callback_inputs.add(v)
- self.assertTrue(
- len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
- )
- last_i = pipe.num_timesteps - 1
- if i == last_i:
- callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
- callback_kwargs["mask_image"] = torch.zeros_like(callback_kwargs["mask_image"])
- return callback_kwargs
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["callback_on_step_end"] = callback_inputs_test
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- inputs["output_type"] = "latent"
-
- output = pipe(**inputs)[0]
- assert output.abs().sum() == 0
-
@slow
@require_torch_gpu
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
index 6b53910e5633..a0de5cceeb75 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import unittest
import numpy as np
@@ -183,7 +182,6 @@ class KandinskyV22PriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
"output_type",
"return_dict",
]
- callback_cfg_params = ["prompt_embeds", "text_encoder_hidden_states", "text_mask"]
test_xformers_attention = False
def get_dummy_components(self):
@@ -237,42 +235,3 @@ def test_attention_slicing_forward_pass(self):
test_max_difference=test_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference,
)
-
- # override default test because no output_type "latent", use "pt" instead
- def test_callback_inputs(self):
- sig = inspect.signature(self.pipeline_class.__call__)
-
- if not ("callback_on_step_end_tensor_inputs" in sig.parameters and "callback_on_step_end" in sig.parameters):
- return
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- self.assertTrue(
- hasattr(pipe, "_callback_tensor_inputs"),
- f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
- )
-
- def callback_inputs_test(pipe, i, t, callback_kwargs):
- missing_callback_inputs = set()
- for v in pipe._callback_tensor_inputs:
- if v not in callback_kwargs:
- missing_callback_inputs.add(v)
- self.assertTrue(
- len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
- )
- last_i = pipe.num_timesteps - 1
- if i == last_i:
- callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
- return callback_kwargs
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["callback_on_step_end"] = callback_inputs_test
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- inputs["num_inference_steps"] = 2
- inputs["output_type"] = "pt"
-
- output = pipe(**inputs)[0]
- assert output.abs().sum() == 0
diff --git a/tests/pipelines/karras_ve/__init__.py b/tests/pipelines/karras_ve/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/karras_ve/test_karras_ve.py b/tests/pipelines/karras_ve/test_karras_ve.py
new file mode 100644
index 000000000000..228d65e508c9
--- /dev/null
+++ b/tests/pipelines/karras_ve/test_karras_ve.py
@@ -0,0 +1,86 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel
+from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch, torch_device
+
+
+enable_full_determinism()
+
+
+class KarrasVePipelineFastTests(unittest.TestCase):
+ @property
+ def dummy_uncond_unet(self):
+ torch.manual_seed(0)
+ model = UNet2DModel(
+ block_out_channels=(32, 64),
+ layers_per_block=2,
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+ up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+ )
+ return model
+
+ def test_inference(self):
+ unet = self.dummy_uncond_unet
+ scheduler = KarrasVeScheduler()
+
+ pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
+
+ generator = torch.manual_seed(0)
+ image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+
+@nightly
+@require_torch
+class KarrasVePipelineIntegrationTests(unittest.TestCase):
+ def test_inference(self):
+ model_id = "google/ncsnpp-celebahq-256"
+ model = UNet2DModel.from_pretrained(model_id)
+ scheduler = KarrasVeScheduler()
+
+ pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.manual_seed(0)
+ image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images
+
+ image_slice = image[0, -3:, -3:, -1]
+ assert image.shape == (1, 256, 256, 3)
+ expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
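The new KarrasVe fast test above follows the usual diffusers testing pattern: seed a generator, run a tiny dummy model for a couple of steps, and compare a 3x3 corner slice of the output against a hard-coded reference. As a rough sketch of that pattern (the helper name and defaults below are illustrative only and are not part of this patch), the assertion can be factored out like this:

```py
# Minimal sketch of the slice-comparison pattern used above; `assert_matches_reference`
# is a hypothetical helper, not something added by this diff.
import numpy as np
import torch


def assert_matches_reference(pipe, expected_slice, num_inference_steps=2, atol=1e-2):
    generator = torch.manual_seed(0)  # fixed seed so the output is reproducible
    image = pipe(num_inference_steps=num_inference_steps, generator=generator, output_type="numpy").images
    image_slice = image[0, -3:, -3:, -1]  # bottom-right 3x3 patch of the last channel
    assert np.abs(image_slice.flatten() - np.asarray(expected_slice)).max() < atol
```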
diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py
index f5be787656c7..7c5ffa2ca24b 100644
--- a/tests/pipelines/pipeline_params.py
+++ b/tests/pipelines/pipeline_params.py
@@ -123,5 +123,3 @@
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
-
-TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
diff --git a/tests/pipelines/shap_e/test_shap_e.py b/tests/pipelines/shap_e/test_shap_e.py
index c7792f097ed5..7b95fdd9e669 100644
--- a/tests/pipelines/shap_e/test_shap_e.py
+++ b/tests/pipelines/shap_e/test_shap_e.py
@@ -160,7 +160,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
- "output_type": "latent",
+ "output_type": "np",
}
return inputs
@@ -176,12 +176,24 @@ def test_shap_e(self):
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
- image = image.cpu().numpy()
- image_slice = image[-3:, -3:]
-
- assert image.shape == (32, 16)
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (20, 32, 32, 3)
+
+ expected_slice = np.array(
+ [
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ ]
+ )
- expected_slice = np.array([-1.0000, -0.6241, 1.0000, -0.8978, -0.6866, 0.7876, -0.7473, -0.2874, 0.6103])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_batch_consistent(self):
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index ee8d9d07cd77..055dbe7a97d4 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -181,7 +181,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
- "output_type": "latent",
+ "output_type": "np",
}
return inputs
@@ -197,12 +197,22 @@ def test_shap_e(self):
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
- image_slice = image[-3:, -3:].cpu().numpy()
+ image_slice = image[0, -3:, -3:, -1]
- assert image.shape == (32, 16)
+ assert image.shape == (20, 32, 32, 3)
expected_slice = np.array(
- [-1.0, 0.40668195, 0.57322013, -0.9469888, 0.4283227, 0.30348337, -0.81094897, 0.74555075, 0.15342723]
+ [
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ 0.00039216,
+ ]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 28d0d07e6948..d6a63b98912a 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -31,7 +31,6 @@
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
- LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
@@ -42,7 +41,6 @@
from diffusers.utils.testing_utils import (
CaptureLogger,
enable_full_determinism,
- load_image,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
@@ -54,12 +52,7 @@
torch_device,
)
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -89,7 +82,6 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
-
assert np.abs(image_slice - expected_slice).max() < 5e-3
except Exception:
error = f"{traceback.format_exc()}"
@@ -107,21 +99,18 @@ class StableDiffusionPipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
- def get_dummy_components(self, time_cond_proj_dim=None):
+ def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=1,
+ block_out_channels=(32, 64),
+ layers_per_block=2,
sample_size=32,
- time_cond_proj_dim=time_cond_proj_dim,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
- norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
@@ -132,23 +121,22 @@ def get_dummy_components(self, time_cond_proj_dim=None):
)
torch.manual_seed(0)
vae = AutoencoderKL(
- block_out_channels=[4, 8],
+ block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
- norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
- intermediate_size=64,
+ intermediate_size=37,
layer_norm_eps=1e-05,
- num_attention_heads=8,
- num_hidden_layers=3,
+ num_attention_heads=4,
+ num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
@@ -163,7 +151,6 @@ def get_dummy_components(self, time_cond_proj_dim=None):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -196,49 +183,7 @@ def test_stable_diffusion_ddim(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3203, 0.4555, 0.4711, 0.3505, 0.3973, 0.4650, 0.5137, 0.3392, 0.4045])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3454, 0.5349, 0.5185, 0.2808, 0.4509, 0.4612, 0.4655, 0.3601, 0.4315])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3454, 0.5349, 0.5185, 0.2808, 0.4509, 0.4612, 0.4655, 0.3601, 0.4315])
+ expected_slice = np.array([0.5756, 0.6118, 0.5005, 0.5041, 0.5471, 0.4726, 0.4976, 0.4865, 0.4864])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -372,7 +317,7 @@ def test_stable_diffusion_ddim_factor_8(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 136, 136, 3)
- expected_slice = np.array([0.4346, 0.5621, 0.5016, 0.3926, 0.4533, 0.4134, 0.5625, 0.5632, 0.5265])
+ expected_slice = np.array([0.5524, 0.5626, 0.6069, 0.4727, 0.386, 0.3995, 0.4613, 0.4328, 0.4269])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -390,7 +335,7 @@ def test_stable_diffusion_pndm(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3411, 0.5032, 0.4704, 0.3135, 0.4323, 0.4740, 0.5150, 0.3498, 0.4022])
+ expected_slice = np.array([0.5122, 0.5712, 0.4825, 0.5053, 0.5646, 0.4769, 0.5179, 0.4894, 0.4994])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -430,7 +375,7 @@ def test_stable_diffusion_k_lms(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3149, 0.5246, 0.4796, 0.3218, 0.4469, 0.4729, 0.5151, 0.3597, 0.3954])
+ expected_slice = np.array([0.4873, 0.5443, 0.4845, 0.5004, 0.5549, 0.4850, 0.5191, 0.4941, 0.5065])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -449,7 +394,7 @@ def test_stable_diffusion_k_euler_ancestral(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3151, 0.5243, 0.4794, 0.3217, 0.4468, 0.4728, 0.5152, 0.3598, 0.3954])
+ expected_slice = np.array([0.4872, 0.5444, 0.4846, 0.5003, 0.5549, 0.4850, 0.5189, 0.4941, 0.5067])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -468,7 +413,7 @@ def test_stable_diffusion_k_euler(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3149, 0.5246, 0.4796, 0.3218, 0.4469, 0.4729, 0.5151, 0.3597, 0.3954])
+ expected_slice = np.array([0.4873, 0.5443, 0.4845, 0.5004, 0.5549, 0.4850, 0.5191, 0.4941, 0.5065])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -540,7 +485,7 @@ def test_stable_diffusion_negative_prompt(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3458, 0.5120, 0.4800, 0.3116, 0.4348, 0.4802, 0.5237, 0.3467, 0.3991])
+ expected_slice = np.array([0.5114, 0.5706, 0.4772, 0.5028, 0.5637, 0.4732, 0.5169, 0.4881, 0.4977])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -693,7 +638,7 @@ def test_stable_diffusion_1_1_pndm(self):
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.4363, 0.4355, 0.3667, 0.4066, 0.3970, 0.3866, 0.4394, 0.4356, 0.4059])
+ expected_slice = np.array([0.43625, 0.43554, 0.36670, 0.40660, 0.39703, 0.38658, 0.43936, 0.43557, 0.40592])
assert np.abs(image_slice - expected_slice).max() < 3e-3
def test_stable_diffusion_v1_4_with_freeu(self):
@@ -720,7 +665,7 @@ def test_stable_diffusion_1_4_pndm(self):
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.5740, 0.4784, 0.3162, 0.6358, 0.5831, 0.5505, 0.5082, 0.5631, 0.5575])
+ expected_slice = np.array([0.57400, 0.47841, 0.31625, 0.63583, 0.58306, 0.55056, 0.50825, 0.56306, 0.55748])
assert np.abs(image_slice - expected_slice).max() < 3e-3
def test_stable_diffusion_ddim(self):
@@ -1112,29 +1057,6 @@ def test_stable_diffusion_compile(self):
inputs["seed"] = seed
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=inputs)
- def test_stable_diffusion_lcm(self):
- unet = UNet2DConditionModel.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", subfolder="unet")
- sd_pipe = StableDiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", unet=unet).to(torch_device)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- inputs["num_inference_steps"] = 6
- inputs["output_type"] = "pil"
-
- image = sd_pipe(**inputs).images[0]
-
- expected_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_full/stable_diffusion_lcm.png"
- )
-
- image = sd_pipe.image_processor.pil_to_numpy(image)
- expected_image = sd_pipe.image_processor.pil_to_numpy(expected_image)
-
- max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
-
- assert max_diff < 1e-2
-
@slow
@require_torch_gpu
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
index a5e8649f060f..d48175a7789b 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
@@ -19,13 +19,11 @@
import numpy as np
import torch
-from parameterized import parameterized
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import (
AutoencoderKL,
- LCMScheduler,
MultiAdapter,
PNDMScheduler,
StableDiffusionAdapterPipeline,
@@ -39,7 +37,6 @@
floats_tensor,
load_image,
load_numpy,
- numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
@@ -57,7 +54,7 @@ class AdapterTests:
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
- def get_dummy_components(self, adapter_type, time_cond_proj_dim=None):
+ def get_dummy_components(self, adapter_type):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
@@ -68,7 +65,6 @@ def get_dummy_components(self, adapter_type, time_cond_proj_dim=None):
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
- time_cond_proj_dim=time_cond_proj_dim,
)
scheduler = PNDMScheduler(skip_prk_steps=True)
torch.manual_seed(0)
@@ -141,100 +137,11 @@ def get_dummy_components(self, adapter_type, time_cond_proj_dim=None):
}
return components
- def get_dummy_components_with_full_downscaling(self, adapter_type):
- """Get dummy components with x8 VAE downscaling and 4 UNet down blocks.
- These dummy components are intended to fully-exercise the T2I-Adapter
- downscaling behavior.
- """
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 32, 32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
- up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = PNDMScheduler(skip_prk_steps=True)
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 32, 32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- torch.manual_seed(0)
-
- if adapter_type == "full_adapter" or adapter_type == "light_adapter":
- adapter = T2IAdapter(
- in_channels=3,
- channels=[32, 32, 32, 64],
- num_res_blocks=2,
- downscale_factor=8,
- adapter_type=adapter_type,
- )
- elif adapter_type == "multi_adapter":
- adapter = MultiAdapter(
- [
- T2IAdapter(
- in_channels=3,
- channels=[32, 32, 32, 64],
- num_res_blocks=2,
- downscale_factor=8,
- adapter_type="full_adapter",
- ),
- T2IAdapter(
- in_channels=3,
- channels=[32, 32, 32, 64],
- num_res_blocks=2,
- downscale_factor=8,
- adapter_type="full_adapter",
- ),
- ]
- )
- else:
- raise ValueError(
- f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter', 'light_adapter', or 'multi_adapter''"
- )
-
- components = {
- "adapter": adapter,
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0, height=64, width=64, num_images=1):
+ def get_dummy_inputs(self, device, seed=0, num_images=1):
if num_images == 1:
- image = floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device)
+ image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
else:
- image = [
- floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device) for _ in range(num_images)
- ]
+ image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
@@ -263,86 +170,10 @@ def test_xformers_attention_forwardGenerator_pass(self):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
- @parameterized.expand(
- [
- # (dim=264) The internal feature map will be 33x33 after initial pixel unshuffling (downscaled x8).
- (((4 * 8 + 1) * 8),),
- # (dim=272) The internal feature map will be 17x17 after the first T2I down block (downscaled x16).
- (((4 * 4 + 1) * 16),),
- # (dim=288) The internal feature map will be 9x9 after the second T2I down block (downscaled x32).
- (((4 * 2 + 1) * 32),),
- # (dim=320) The internal feature map will be 5x5 after the third T2I down block (downscaled x64).
- (((4 * 1 + 1) * 64),),
- ]
- )
- def test_multiple_image_dimensions(self, dim):
- """Test that the T2I-Adapter pipeline supports any input dimension that
- is divisible by the adapter's `downscale_factor`. This test was added in
- response to an issue where the T2I Adapter's downscaling padding
- behavior did not match the UNet's behavior.
-
- Note that we have selected `dim` values to produce odd resolutions at
- each downscaling level.
- """
- components = self.get_dummy_components_with_full_downscaling()
- sd_pipe = StableDiffusionAdapterPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device, height=dim, width=dim)
- image = sd_pipe(**inputs).images
-
- assert image.shape == (1, dim, dim, 3)
-
- def test_adapter_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionAdapterPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.4535, 0.5493, 0.4359, 0.5452, 0.6086, 0.4441, 0.5544, 0.501, 0.4859])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_adapter_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionAdapterPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.4535, 0.5493, 0.4359, 0.5452, 0.6086, 0.4441, 0.5544, 0.501, 0.4859])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
- def get_dummy_components(self, time_cond_proj_dim=None):
- return super().get_dummy_components("full_adapter", time_cond_proj_dim=time_cond_proj_dim)
-
- def get_dummy_components_with_full_downscaling(self):
- return super().get_dummy_components_with_full_downscaling("full_adapter")
+ def get_dummy_components(self):
+ return super().get_dummy_components("full_adapter")
def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -361,11 +192,8 @@ def test_stable_diffusion_adapter_default_case(self):
class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
- def get_dummy_components(self, time_cond_proj_dim=None):
- return super().get_dummy_components("light_adapter", time_cond_proj_dim=time_cond_proj_dim)
-
- def get_dummy_components_with_full_downscaling(self):
- return super().get_dummy_components_with_full_downscaling("light_adapter")
+ def get_dummy_components(self):
+ return super().get_dummy_components("light_adapter")
def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -384,14 +212,11 @@ def test_stable_diffusion_adapter_default_case(self):
class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
- def get_dummy_components(self, time_cond_proj_dim=None):
- return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)
+ def get_dummy_components(self):
+ return super().get_dummy_components("multi_adapter")
- def get_dummy_components_with_full_downscaling(self):
- return super().get_dummy_components_with_full_downscaling("multi_adapter")
-
- def get_dummy_inputs(self, device, height=64, width=64, seed=0):
- inputs = super().get_dummy_inputs(device, seed, height=height, width=width, num_images=2)
+ def get_dummy_inputs(self, device, seed=0):
+ inputs = super().get_dummy_inputs(device, seed, num_images=2)
inputs["adapter_conditioning_scale"] = [0.5, 0.5]
return inputs
@@ -598,334 +423,117 @@ def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
- def test_stable_diffusion_adapter_color(self):
- adapter_model = "TencentARC/t2iadapter_color_sd14v1"
- sd_model = "CompVis/stable-diffusion-v1-4"
- prompt = "snail"
- image_url = (
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/color.png"
- )
- input_channels = 3
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_color_sd14v1.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_depth(self):
- adapter_model = "TencentARC/t2iadapter_depth_sd14v1"
- sd_model = "CompVis/stable-diffusion-v1-4"
- prompt = "snail"
- image_url = (
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/color.png"
- )
- input_channels = 3
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_color_sd14v1.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_depth_sd_v14(self):
- adapter_model = "TencentARC/t2iadapter_depth_sd14v1"
- sd_model = "CompVis/stable-diffusion-v1-4"
- prompt = "desk"
- image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png"
- input_channels = 3
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd14v1.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_depth_sd_v15(self):
- adapter_model = "TencentARC/t2iadapter_depth_sd15v2"
- sd_model = "runwayml/stable-diffusion-v1-5"
- prompt = "desk"
- image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png"
- input_channels = 3
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd15v2.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_keypose_sd_v14(self):
- adapter_model = "TencentARC/t2iadapter_keypose_sd14v1"
- sd_model = "CompVis/stable-diffusion-v1-4"
- prompt = "person"
- image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/person_keypose.png"
- input_channels = 3
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_keypose_sd14v1.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_openpose_sd_v14(self):
- adapter_model = "TencentARC/t2iadapter_openpose_sd14v1"
- sd_model = "CompVis/stable-diffusion-v1-4"
- prompt = "person"
- image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/iron_man_pose.png"
- input_channels = 3
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_openpose_sd14v1.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_seg_sd_v14(self):
- adapter_model = "TencentARC/t2iadapter_seg_sd14v1"
- sd_model = "CompVis/stable-diffusion-v1-4"
- prompt = "motorcycle"
- image_url = (
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motor.png"
- )
- input_channels = 3
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_seg_sd14v1.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_zoedepth_sd_v15(self):
- adapter_model = "TencentARC/t2iadapter_zoedepth_sd15v1"
- sd_model = "runwayml/stable-diffusion-v1-5"
- prompt = "motorcycle"
- image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motorcycle.png"
- input_channels = 3
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_zoedepth_sd15v1.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_canny_sd_v14(self):
- adapter_model = "TencentARC/t2iadapter_canny_sd14v1"
- sd_model = "CompVis/stable-diffusion-v1-4"
- prompt = "toy"
- image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
- input_channels = 1
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd14v1.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_canny_sd_v15(self):
- adapter_model = "TencentARC/t2iadapter_canny_sd15v2"
- sd_model = "runwayml/stable-diffusion-v1-5"
- prompt = "toy"
- image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
- input_channels = 1
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd15v2.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
-
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_sketch_sd14(self):
- adapter_model = "TencentARC/t2iadapter_sketch_sd14v1"
- sd_model = "CompVis/stable-diffusion-v1-4"
- prompt = "cat"
- image_url = (
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png"
- )
- input_channels = 1
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd14v1.npy"
-
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
-
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
-
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
-
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
-
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
-
- def test_stable_diffusion_adapter_sketch_sd15(self):
- adapter_model = "TencentARC/t2iadapter_sketch_sd15v2"
- sd_model = "runwayml/stable-diffusion-v1-5"
- prompt = "cat"
- image_url = (
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png"
- )
- input_channels = 1
- out_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd15v2.npy"
+ def test_stable_diffusion_adapter(self):
+ test_cases = [
+ (
+ "TencentARC/t2iadapter_color_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "snail",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/color.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_color_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_depth_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "desk",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_depth_sd15v2",
+ "runwayml/stable-diffusion-v1-5",
+ "desk",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/desk_depth.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_depth_sd15v2.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_keypose_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "person",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/person_keypose.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_keypose_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_openpose_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "person",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/iron_man_pose.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_openpose_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_seg_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "motorcycle",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motor.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_seg_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_zoedepth_sd15v1",
+ "runwayml/stable-diffusion-v1-5",
+ "motorcycle",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/motorcycle.png",
+ 3,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_zoedepth_sd15v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_canny_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "toy",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png",
+ 1,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_canny_sd15v2",
+ "runwayml/stable-diffusion-v1-5",
+ "toy",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png",
+ 1,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd15v2.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_sketch_sd14v1",
+ "CompVis/stable-diffusion-v1-4",
+ "cat",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png",
+ 1,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd14v1.npy",
+ ),
+ (
+ "TencentARC/t2iadapter_sketch_sd15v2",
+ "runwayml/stable-diffusion-v1-5",
+ "cat",
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/edge.png",
+ 1,
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_sketch_sd15v2.npy",
+ ),
+ ]
- image = load_image(image_url)
- expected_out = load_numpy(out_url)
- if input_channels == 1:
- image = image.convert("L")
+ for adapter_model, sd_model, prompt, image_url, input_channels, out_url in test_cases:
+ image = load_image(image_url)
+ expected_out = load_numpy(out_url)
- adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
+ if input_channels == 1:
+ image = image.convert("L")
- pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
+ adapter = T2IAdapter.from_pretrained(adapter_model, torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
+ pipe = StableDiffusionAdapterPipeline.from_pretrained(sd_model, adapter=adapter, safety_checker=None)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.enable_attention_slicing()
- generator = torch.Generator(device="cpu").manual_seed(0)
+ generator = torch.Generator(device="cpu").manual_seed(0)
- out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
+ out = pipe(prompt=prompt, image=image, generator=generator, num_inference_steps=2, output_type="np").images
- max_diff = numpy_cosine_similarity_distance(out.flatten(), expected_out.flatten())
- assert max_diff < 1e-2
+ self.assertTrue(np.allclose(out, expected_out))
def test_stable_diffusion_adapter_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
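The consolidated slow test above folds the per-checkpoint adapter tests into a single loop over `(adapter_model, sd_model, prompt, image_url, input_channels, out_url)` tuples. One trade-off of the loop is that a failing checkpoint stops the remaining cases and the test report no longer names which checkpoint failed. A purely illustrative alternative (not part of this patch; note that `parameterized` is removed from this file's imports by the diff above) keeps the same data table but expands each tuple into its own reported test case:

```py
# Illustrative sketch only: the case list mirrors the loop above, but each tuple
# becomes a separately reported test. ADAPTER_CASES shows a single example entry.
import unittest

from parameterized import parameterized

ADAPTER_CASES = [
    (
        "TencentARC/t2iadapter_canny_sd14v1",
        "CompVis/stable-diffusion-v1-4",
        "toy",
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png",
        1,
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/t2iadapter_canny_sd14v1.npy",
    ),
    # ... the remaining (adapter_model, sd_model, prompt, image_url, input_channels, out_url) tuples
]


class StableDiffusionAdapterPipelineSlowTestsParameterized(unittest.TestCase):
    @parameterized.expand(ADAPTER_CASES)
    def test_stable_diffusion_adapter(self, adapter_model, sd_model, prompt, image_url, input_channels, out_url):
        # Body would be identical to the loop body above: load the conditioning image and
        # expected output, build the adapter pipeline, run 2 steps, and compare the results.
        ...
```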
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py
index 2e9d7c3b437b..cd688c3beb37 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py
@@ -36,6 +36,7 @@
load_numpy,
nightly,
numpy_cosine_similarity_distance,
+ print_tensor_test,
require_torch_gpu,
slow,
torch_device,
@@ -201,6 +202,7 @@ def test_stable_diffusion_img_variation_pipeline_default(self):
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.8449, 0.9079, 0.7571, 0.7873, 0.8348, 0.7010, 0.6694, 0.6873, 0.6138])
+ print_tensor_test(image_slice)
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 1e-4
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index fb56d868f1cc..be8f067b1b78 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -24,11 +24,9 @@
from diffusers import (
AutoencoderKL,
- AutoencoderTiny,
DDIMScheduler,
DPMSolverMultistepScheduler,
HeunDiscreteScheduler,
- LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
@@ -53,7 +51,6 @@
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -102,14 +99,12 @@ class StableDiffusionImg2ImgPipelineFastTests(
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
- def get_dummy_components(self, time_cond_proj_dim=None):
+ def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
- time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
@@ -150,13 +145,9 @@ def get_dummy_components(self, time_cond_proj_dim=None):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
- def get_dummy_tiny_autoencoder(self):
- return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)
-
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
@@ -190,42 +181,6 @@ def test_stable_diffusion_img2img_default_case(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- def test_stable_diffusion_img2img_default_case_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionImg2ImgPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 32, 32, 3)
- expected_slice = np.array([0.5709, 0.4614, 0.4587, 0.5978, 0.5298, 0.6910, 0.6240, 0.5212, 0.5454])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_stable_diffusion_img2img_default_case_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionImg2ImgPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 32, 32, 3)
- expected_slice = np.array([0.5709, 0.4614, 0.4587, 0.5978, 0.5298, 0.6910, 0.6240, 0.5212, 0.5454])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
def test_stable_diffusion_img2img_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -281,23 +236,6 @@ def test_stable_diffusion_img2img_k_lms(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- def test_stable_diffusion_img2img_tiny_autoencoder(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionImg2ImgPipeline(**components)
- sd_pipe.vae = self.get_dummy_tiny_autoencoder()
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 32, 32, 3)
- expected_slice = np.array([0.00669, 0.00669, 0.0, 0.00693, 0.00858, 0.0, 0.00567, 0.00515, 0.00125])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
@skip_mps
def test_save_load_local(self):
return super().test_save_load_local()
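
The removed `test_stable_diffusion_img2img_tiny_autoencoder` above swaps the pipeline's VAE for a dummy `AutoencoderTiny` before running inference. A minimal sketch of the same swap against public checkpoints, assuming the `runwayml/stable-diffusion-v1-5` and `madebyollin/taesd` (TAESD) repos as illustrative placeholders rather than anything named in this diff:

```py
import torch
from PIL import Image

from diffusers import AutoencoderTiny, StableDiffusionImg2ImgPipeline

# Load a regular img2img pipeline, then replace its full KL VAE with the tiny
# autoencoder -- the production-sized counterpart of the dummy
# AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4) used in the test.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

init_image = Image.new("RGB", (512, 512), "white")  # stand-in for a real input image
image = pipe("a colorful abstract painting", image=init_image, strength=0.75).images[0]
```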
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index a69edb869641..e485bc9123b0 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -29,7 +29,6 @@
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
- LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionInpaintPipeline,
@@ -51,11 +50,7 @@
torch_device,
)
-from ..pipeline_params import (
- TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
- TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
-)
+from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -105,13 +100,11 @@ class StableDiffusionInpaintPipelineFastTests(
image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"mask", "masked_image_latents"})
- def get_dummy_components(self, time_cond_proj_dim=None):
+ def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
- time_cond_proj_dim=time_cond_proj_dim,
layers_per_block=2,
sample_size=32,
in_channels=9,
@@ -153,7 +146,6 @@ def get_dummy_components(self, time_cond_proj_dim=None):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -209,42 +201,6 @@ def test_stable_diffusion_inpaint(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- def test_stable_diffusion_inpaint_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionInpaintPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.4931, 0.5988, 0.4569, 0.5556, 0.6650, 0.5087, 0.5966, 0.5358, 0.5269])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_inpaint_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionInpaintPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.4931, 0.5988, 0.4569, 0.5556, 0.6650, 0.5087, 0.5966, 0.5358, 0.5269])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
def test_stable_diffusion_inpaint_image_tensor(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -327,12 +283,11 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
- def get_dummy_components(self, time_cond_proj_dim=None):
+ def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
- time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
@@ -373,7 +328,6 @@ def get_dummy_components(self, time_cond_proj_dim=None):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
@@ -422,42 +376,6 @@ def test_stable_diffusion_inpaint(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- def test_stable_diffusion_inpaint_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionInpaintPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.6240, 0.5355, 0.5649, 0.5378, 0.5374, 0.6242, 0.5132, 0.5347, 0.5396])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_inpaint_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionInpaintPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.6240, 0.5355, 0.5649, 0.5378, 0.5374, 0.6242, 0.5132, 0.5347, 0.5396])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
def test_stable_diffusion_inpaint_2_images(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
index 69b36cb3bb8a..07fd8e1b5192 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
@@ -45,7 +45,6 @@
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -61,7 +60,6 @@ class StableDiffusionInstructPix2PixPipelineFastTests(
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"image_latents"}) - {"negative_prompt_embeds"}
def get_dummy_components(self):
torch.manual_seed(0)
@@ -234,34 +232,6 @@ def test_latents_input(self):
max_diff = np.abs(out - out_latents_inputs).max()
self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
- # Override the default test_callback_cfg because pix2pix create inputs for cfg differently
- def test_callback_cfg(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- def callback_no_cfg(pipe, i, t, callback_kwargs):
- if i == 1:
- for k, w in callback_kwargs.items():
- if k in self.callback_cfg_params:
- callback_kwargs[k] = callback_kwargs[k].chunk(3)[0]
- pipe._guidance_scale = 1.0
-
- return callback_kwargs
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["guidance_scale"] = 1.0
- inputs["num_inference_steps"] = 2
- out_no_cfg = pipe(**inputs)[0]
-
- inputs["guidance_scale"] = 7.5
- inputs["callback_on_step_end"] = callback_no_cfg
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- out_callback_no_cfg = pipe(**inputs)[0]
-
- assert out_no_cfg.shape == out_callback_no_cfg.shape
-
@slow
@require_torch_gpu
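
`test_callback_cfg`, removed above, exercises the `callback_on_step_end` hook: after the first step it keeps only part of each batched tensor and zeroes the pipeline's guidance scale so classifier-free guidance is skipped for the rest of the run. A user-level sketch of that pattern, assuming a diffusers version that exposes `callback_on_step_end` and using a placeholder checkpoint (standard Stable Diffusion chunks its CFG batch in two, not three as in pix2pix):

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")


def disable_cfg_after_first_step(pipeline, step_index, timestep, callback_kwargs):
    # After step 1, keep only the conditional half of the prompt embeddings
    # and zero the guidance scale so later steps skip the unconditional pass.
    if step_index == 1:
        callback_kwargs["prompt_embeds"] = callback_kwargs["prompt_embeds"].chunk(2)[-1]
        pipeline._guidance_scale = 0.0
    return callback_kwargs


image = pipe(
    "a photo of an astronaut riding a horse on mars",
    num_inference_steps=20,
    guidance_scale=7.5,
    callback_on_step_end=disable_cfg_after_first_step,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
```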
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index ed295f792f99..a0e66c45b5a1 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -43,12 +43,7 @@
torch_device,
)
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -63,7 +58,6 @@ class StableDiffusion2PipelineFastTests(
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
@@ -123,7 +117,6 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
index 5cf8b38d4da1..149c90698f1c 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
@@ -56,7 +56,6 @@
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -76,7 +75,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"depth_mask"})
def get_dummy_components(self):
torch.manual_seed(0)
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
index 41b9f83914a6..1e726b95960f 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
@@ -33,11 +33,7 @@
torch_device,
)
-from ..pipeline_params import (
- TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
- TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
-)
+from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -54,7 +50,6 @@ class StableDiffusion2InpaintPipelineFastTests(
[]
) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"mask", "masked_image_latents"})
def get_dummy_components(self):
torch.manual_seed(0)
@@ -108,7 +103,6 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
- "image_encoder": None,
}
return components
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
index 09034789c61c..4d6bd85d981a 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
@@ -127,7 +127,6 @@ def test_stable_diffusion_v_pred_ddim(self):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
- image_encoder=None,
requires_safety_checker=False,
)
sd_pipe = sd_pipe.to(device)
@@ -177,7 +176,6 @@ def test_stable_diffusion_v_pred_k_euler(self):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
- image_encoder=None,
requires_safety_checker=False,
)
sd_pipe = sd_pipe.to(device)
@@ -238,7 +236,6 @@ def test_stable_diffusion_v_pred_fp16(self):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
- image_encoder=None,
requires_safety_checker=False,
)
sd_pipe = sd_pipe.to(torch_device)
@@ -370,9 +367,9 @@ def test_stable_diffusion_attention_slicing_v_pred(self):
output = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")
image = output.images
- # make sure that more than 3.0 GB is allocated
+ # make sure that more than 5.5 GB is allocated
mem_bytes = torch.cuda.max_memory_allocated()
- assert mem_bytes > 3 * 10**9
+ assert mem_bytes > 5.5 * 10**9
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_chunked.flatten())
assert max_diff < 1e-3
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 59f0c0151d3a..cebd860a4379 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -27,49 +27,32 @@
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
- LCMScheduler,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- load_image,
- numpy_cosine_similarity_distance,
- require_torch_gpu,
- slow,
- torch_device,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
enable_full_determinism()
-class StableDiffusionXLPipelineFastTests(
- PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
-):
+class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionXLPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
- def get_dummy_components(self, time_cond_proj_dim=None):
+ def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(2, 4),
layers_per_block=2,
- time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
@@ -131,8 +114,8 @@ def get_dummy_components(self, time_cond_proj_dim=None):
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
- "image_encoder": None,
- "feature_extractor": None,
+ # "safety_checker": None,
+ # "feature_extractor": None,
}
return components
@@ -166,42 +149,6 @@ def test_stable_diffusion_xl_euler(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- def test_stable_diffusion_xl_euler_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_xl_euler_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
def test_stable_diffusion_xl_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
@@ -286,9 +233,6 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
-
@require_torch_gpu
def test_stable_diffusion_xl_offloads(self):
pipes = []
@@ -354,107 +298,6 @@ def test_stable_diffusion_xl_img2img_prompt_embeds_only(self):
# make sure that it's equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
- def test_stable_diffusion_two_xl_mixture_of_denoiser_fast(self):
- components = self.get_dummy_components()
- pipe_1 = StableDiffusionXLPipeline(**components).to(torch_device)
- pipe_1.unet.set_default_attn_processor()
- pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
- pipe_2.unet.set_default_attn_processor()
-
- def assert_run_mixture(
- num_steps,
- split,
- scheduler_cls_orig,
- expected_tss,
- num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps,
- ):
- inputs = self.get_dummy_inputs(torch_device)
- inputs["num_inference_steps"] = num_steps
-
- class scheduler_cls(scheduler_cls_orig):
- pass
-
- pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
- pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
-
- # Let's retrieve the number of timesteps we want to use
- pipe_1.scheduler.set_timesteps(num_steps)
- expected_steps = pipe_1.scheduler.timesteps.tolist()
-
- if pipe_1.scheduler.order == 2:
- expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
- expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
- expected_steps = expected_steps_1 + expected_steps_2
- else:
- expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
- expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
-
- # now we monkey patch step `done_steps`
- # list into the step function for testing
- done_steps = []
- old_step = copy.copy(scheduler_cls.step)
-
- def new_step(self, *args, **kwargs):
- done_steps.append(args[1].cpu().item()) # args[1] is always the passed `t`
- return old_step(self, *args, **kwargs)
-
- scheduler_cls.step = new_step
-
- inputs_1 = {
- **inputs,
- **{
- "denoising_end": 1.0 - (split / num_train_timesteps),
- "output_type": "latent",
- },
- }
- latents = pipe_1(**inputs_1).images[0]
-
- assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
-
- inputs_2 = {
- **inputs,
- **{
- "denoising_start": 1.0 - (split / num_train_timesteps),
- "image": latents,
- },
- }
- pipe_2(**inputs_2).images[0]
-
- assert expected_steps_2 == done_steps[len(expected_steps_1) :]
- assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
-
- steps = 10
- for split in [300, 700]:
- for scheduler_cls_timesteps in [
- (EulerDiscreteScheduler, [901, 801, 701, 601, 501, 401, 301, 201, 101, 1]),
- (
- HeunDiscreteScheduler,
- [
- 901.0,
- 801.0,
- 801.0,
- 701.0,
- 701.0,
- 601.0,
- 601.0,
- 501.0,
- 501.0,
- 401.0,
- 401.0,
- 301.0,
- 301.0,
- 201.0,
- 201.0,
- 101.0,
- 101.0,
- 1.0,
- 1.0,
- ],
- ),
- ]:
- assert_run_mixture(steps, split, scheduler_cls_timesteps[0], scheduler_cls_timesteps[1])
-
- @slow
def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
components = self.get_dummy_components()
pipe_1 = StableDiffusionXLPipeline(**components).to(torch_device)
@@ -482,13 +325,8 @@ class scheduler_cls(scheduler_cls_orig):
pipe_1.scheduler.set_timesteps(num_steps)
expected_steps = pipe_1.scheduler.timesteps.tolist()
- if pipe_1.scheduler.order == 2:
- expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
- expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
- expected_steps = expected_steps_1 + expected_steps_2
- else:
- expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
- expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
+ expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
+ expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
# now we monkey patch step `done_steps`
# list into the step function for testing
@@ -738,7 +576,6 @@ def new_step(self, *args, **kwargs):
]:
assert_run_mixture(steps, split, scheduler_cls_timesteps[0], scheduler_cls_timesteps[1])
- @slow
def test_stable_diffusion_three_xl_mixture_of_denoiser(self):
components = self.get_dummy_components()
pipe_1 = StableDiffusionXLPipeline(**components).to(torch_device)
@@ -771,18 +608,13 @@ class scheduler_cls(scheduler_cls_orig):
split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
+ expected_steps_1 = expected_steps[:split_1_ts]
+ expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
+ expected_steps_3 = expected_steps[split_2_ts:]
- if pipe_1.scheduler.order == 2:
- expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
- expected_steps_2 = expected_steps_1[-1:] + list(
- filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
- )
- expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
- expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
- else:
- expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
- expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
- expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
+ expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+ expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
+ expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
# now we monkey patch step `done_steps`
# list into the step function for testing
@@ -937,32 +769,3 @@ def test_stable_diffusion_xl_save_from_pretrained(self):
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
-
-
-@slow
-class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
- def test_stable_diffusion_lcm(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel.from_pretrained(
- "latent-consistency/lcm-ssd-1b", torch_dtype=torch.float16, variant="fp16"
- )
- sd_pipe = StableDiffusionXLPipeline.from_pretrained(
- "segmind/SSD-1B", unet=unet, torch_dtype=torch.float16, variant="fp16"
- ).to(torch_device)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "a red car standing on the side of the street"
-
- image = sd_pipe(prompt, num_inference_steps=4, guidance_scale=8.0).images[0]
-
- expected_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_full/stable_diffusion_ssd_1b_lcm.png"
- )
-
- image = sd_pipe.image_processor.pil_to_numpy(image)
- expected_image = sd_pipe.image_processor.pil_to_numpy(expected_image)
-
- max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
-
- assert max_diff < 1e-2
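
The mixture-of-denoisers tests trimmed above verify, scheduler step by scheduler step, that a run stopped at `denoising_end` lines up with one resumed via `denoising_start`. The user-facing form of that handoff is the SDXL base + refiner split; a rough sketch, with the refiner checkpoint ID assumed rather than taken from this diff:

```py
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

prompt = "a majestic lion jumping from a big stone at night"
split = 0.8  # fraction of the timestep schedule handled by the base model

# The base pipeline denoises down to `denoising_end` and returns raw latents...
latents = base(
    prompt, num_inference_steps=30, denoising_end=split, output_type="latent"
).images
# ...and the img2img pipeline picks up the remaining steps from the same fraction.
image = refiner(
    prompt, image=latents, num_inference_steps=30, denoising_start=split
).images[0]
```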
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
index f63ee8be1dd0..92c22ca2c34c 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
import random
import unittest
@@ -26,41 +25,27 @@
from diffusers import (
AutoencoderKL,
EulerDiscreteScheduler,
- LCMScheduler,
MultiAdapter,
StableDiffusionXLAdapterPipeline,
T2IAdapter,
UNet2DConditionModel,
)
-from diffusers.utils import load_image, logging
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- numpy_cosine_similarity_distance,
- require_torch_gpu,
- slow,
- torch_device,
-)
+from diffusers.utils import logging
+from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
- assert_mean_pixel_difference,
-)
+from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
enable_full_determinism()
-class StableDiffusionXLAdapterPipelineFastTests(
- PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
-):
+class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionXLAdapterPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
- def get_dummy_components(self, adapter_type="full_adapter_xl", time_cond_proj_dim=None):
+ def get_dummy_components(self, adapter_type="full_adapter_xl"):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
@@ -78,7 +63,6 @@ def get_dummy_components(self, adapter_type="full_adapter_xl", time_cond_proj_di
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
- time_cond_proj_dim=time_cond_proj_dim,
)
scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
@@ -163,119 +147,11 @@ def get_dummy_components(self, adapter_type="full_adapter_xl", time_cond_proj_di
}
return components
- def get_dummy_components_with_full_downscaling(self, adapter_type="full_adapter_xl"):
- """Get dummy components with x8 VAE downscaling and 3 UNet down blocks.
- These dummy components are intended to fully-exercise the T2I-Adapter
- downscaling behavior.
- """
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
- # SD2-specific config below
- attention_head_dim=2,
- use_linear_projection=True,
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=1,
- projection_class_embeddings_input_dim=80, # 6 * 8 + 32
- cross_attention_dim=64,
- )
- scheduler = EulerDiscreteScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- steps_offset=1,
- beta_schedule="scaled_linear",
- timestep_spacing="leading",
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 32, 32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- if adapter_type == "full_adapter_xl":
- adapter = T2IAdapter(
- in_channels=3,
- channels=[32, 32, 64],
- num_res_blocks=2,
- downscale_factor=16,
- adapter_type=adapter_type,
- )
- elif adapter_type == "multi_adapter":
- adapter = MultiAdapter(
- [
- T2IAdapter(
- in_channels=3,
- channels=[32, 32, 64],
- num_res_blocks=2,
- downscale_factor=16,
- adapter_type="full_adapter_xl",
- ),
- T2IAdapter(
- in_channels=3,
- channels=[32, 32, 64],
- num_res_blocks=2,
- downscale_factor=16,
- adapter_type="full_adapter_xl",
- ),
- ]
- )
- else:
- raise ValueError(
- f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter_xl', or 'multi_adapter''"
- )
-
- components = {
- "adapter": adapter,
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- # "safety_checker": None,
- # "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0, height=64, width=64, num_images=1):
+ def get_dummy_inputs(self, device, seed=0, num_images=1):
if num_images == 1:
- image = floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device)
+ image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
else:
- image = [
- floats_tensor((1, 3, height, width), rng=random.Random(seed)).to(device) for _ in range(num_images)
- ]
+ image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
@@ -308,33 +184,6 @@ def test_stable_diffusion_adapter_default_case(self):
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
- @parameterized.expand(
- [
- # (dim=144) The internal feature map will be 9x9 after initial pixel unshuffling (downscaled x16).
- (((4 * 2 + 1) * 16),),
- # (dim=160) The internal feature map will be 5x5 after the first T2I down block (downscaled x32).
- (((4 * 1 + 1) * 32),),
- ]
- )
- def test_multiple_image_dimensions(self, dim):
- """Test that the T2I-Adapter pipeline supports any input dimension that
- is divisible by the adapter's `downscale_factor`. This test was added in
- response to an issue where the T2I Adapter's downscaling padding
- behavior did not match the UNet's behavior.
-
- Note that we have selected `dim` values to produce odd resolutions at
- each downscaling level.
- """
- components = self.get_dummy_components_with_full_downscaling()
- sd_pipe = StableDiffusionXLAdapterPipeline(**components)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device, height=dim, width=dim)
- image = sd_pipe(**inputs).images
-
- assert image.shape == (1, dim, dim, 3)
-
@parameterized.expand(["full_adapter", "full_adapter_xl", "light_adapter"])
def test_total_downscale_factor(self, adapter_type):
"""Test that the T2IAdapter correctly reports its total_downscale_factor."""
@@ -366,63 +215,15 @@ def test_total_downscale_factor(self, adapter_type):
expected_out_image_size,
)
- def test_save_load_optional_components(self):
- return self._test_save_load_optional_components()
-
- def test_adapter_sdxl_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLAdapterPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5425, 0.5385, 0.4964, 0.5045, 0.6149, 0.4974, 0.5469, 0.5332, 0.5426])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_adapter_sdxl_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLAdapterPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5425, 0.5385, 0.4964, 0.5045, 0.6149, 0.4974, 0.5469, 0.5332, 0.5426])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
class StableDiffusionXLMultiAdapterPipelineFastTests(
StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
):
- def get_dummy_components(self, time_cond_proj_dim=None):
- return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)
-
- def get_dummy_components_with_full_downscaling(self):
- return super().get_dummy_components_with_full_downscaling("multi_adapter")
+ def get_dummy_components(self):
+ return super().get_dummy_components("multi_adapter")
- def get_dummy_inputs(self, device, seed=0, height=64, width=64):
- inputs = super().get_dummy_inputs(device, seed, height, width, num_images=2)
+ def get_dummy_inputs(self, device, seed=0):
+ inputs = super().get_dummy_inputs(device, seed, num_images=2)
inputs["adapter_conditioning_scale"] = [0.5, 0.5]
return inputs
@@ -612,90 +413,3 @@ def test_inference_batch_single_identical(
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_batch[0][0], output[0][0])
-
- def test_adapter_sdxl_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLAdapterPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
-
- debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
- print(",".join(debug))
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_adapter_sdxl_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLAdapterPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
-
- debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
- print(",".join(debug))
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-
-@slow
-@require_torch_gpu
-class AdapterSDXLPipelineSlowTests(unittest.TestCase):
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_canny_lora(self):
- adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16).to(
- "cpu"
- )
- pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- adapter=adapter,
- torch_dtype=torch.float16,
- variant="fp16",
- )
- pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
- pipe.enable_sequential_cpu_offload()
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "toy"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
- )
-
- images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
-
- assert images[0].shape == (768, 512, 3)
-
- original_image = images[0, -3:, -3:, -1].flatten()
- expected_image = np.array(
- [0.50346327, 0.50708383, 0.50719553, 0.5135172, 0.5155377, 0.5066059, 0.49680984, 0.5005894, 0.48509413]
- )
- assert numpy_cosine_similarity_distance(original_image, expected_image) < 1e-4
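
The multi-adapter fast tests in this file wrap two dummy `T2IAdapter`s in a `MultiAdapter` and pass one `adapter_conditioning_scale` per adapter. A sketch of the same wiring with public checkpoints; the lineart ID mirrors the removed slow test, while the canny ID and the blank conditioning images are placeholders for illustration:

```py
import torch
from PIL import Image

from diffusers import MultiAdapter, StableDiffusionXLAdapterPipeline, T2IAdapter

# Two SDXL adapters combined into one MultiAdapter.
adapters = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16),
        T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16),
    ]
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    adapter=adapters,
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

lineart_map = Image.new("RGB", (1024, 1024), "black")  # stand-ins for real control images
canny_map = Image.new("RGB", (1024, 1024), "black")

image = pipe(
    "a robot in a greenhouse",
    image=[lineart_map, canny_map],
    adapter_conditioning_scale=[0.5, 0.5],  # one scale per adapter, as in the tests
    num_inference_steps=30,
).images[0]
```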
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index 7cad3fff0a47..ba7d3e8be30f 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -18,21 +18,11 @@
import numpy as np
import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextConfig,
- CLIPTextModel,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import (
AutoencoderKL,
- AutoencoderTiny,
EulerDiscreteScheduler,
- LCMScheduler,
StableDiffusionXLImg2ImgPipeline,
UNet2DConditionModel,
)
@@ -47,9 +37,8 @@
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
)
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
enable_full_determinism()
@@ -62,11 +51,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
- {"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
- )
- def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
+ def get_dummy_components(self, skip_first_text_encoder=False):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
@@ -74,7 +60,6 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
sample_size=32,
in_channels=4,
out_channels=4,
- time_cond_proj_dim=time_cond_proj_dim,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
@@ -103,31 +88,6 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
latent_channels=4,
sample_size=128,
)
- torch.manual_seed(0)
- image_encoder_config = CLIPVisionConfig(
- hidden_size=32,
- image_size=224,
- projection_dim=32,
- intermediate_size=37,
- num_attention_heads=4,
- num_channels=3,
- num_hidden_layers=5,
- patch_size=14,
- )
-
- image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
-
- feature_extractor = CLIPImageProcessor(
- crop_size=224,
- do_center_crop=True,
- do_normalize=True,
- do_resize=True,
- image_mean=[0.48145466, 0.4578275, 0.40821073],
- image_std=[0.26862954, 0.26130258, 0.27577711],
- resample=3,
- size=224,
- )
-
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
@@ -158,14 +118,9 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
- "image_encoder": image_encoder,
- "feature_extractor": feature_extractor,
}
return components
- def get_dummy_tiny_autoencoder(self):
- return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)
-
def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
@@ -209,44 +164,6 @@ def test_stable_diffusion_xl_img2img_euler(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- def test_stable_diffusion_xl_img2img_euler_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 32, 32, 3)
-
- expected_slice = np.array([0.5604, 0.4352, 0.4717, 0.5844, 0.5101, 0.6704, 0.6290, 0.5460, 0.5286])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_xl_img2img_euler_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 32, 32, 3)
-
- expected_slice = np.array([0.5604, 0.4352, 0.4717, 0.5844, 0.5101, 0.6704, 0.6290, 0.5460, 0.5286])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
@@ -299,23 +216,6 @@ def test_stable_diffusion_xl_img2img_negative_prompt_embeds(self):
# make sure that it's equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
- def test_stable_diffusion_xl_img2img_tiny_autoencoder(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
- sd_pipe.vae = self.get_dummy_tiny_autoencoder()
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 32, 32, 3)
- expected_slice = np.array([0.0, 0.0, 0.0106, 0.0, 0.0, 0.0087, 0.0052, 0.0062, 0.0177])
-
- assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
-
@require_torch_gpu
def test_stable_diffusion_xl_offloads(self):
pipes = []
@@ -441,7 +341,7 @@ def test_stable_diffusion_xl_img2img_negative_conditions(self):
class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
- PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+ PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLImg2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
@@ -513,8 +413,6 @@ def get_dummy_components(self):
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
- "image_encoder": None,
- "feature_extractor": None,
}
return components
@@ -702,6 +600,3 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
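
The `*_lcm` fast tests removed from this file build the LCM setup with dummy components (`time_cond_proj_dim=256` on the UNet plus an `LCMScheduler` swap, optionally with explicit `timesteps=[999, 499]`). The full-sized recipe is the SSD-1B integration test deleted earlier in this diff, restated here as a standalone usage sketch:

```py
import torch
from diffusers import LCMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel

# LCM-distilled UNet for SSD-1B, paired with the matching base checkpoint,
# as in the removed integration test.
unet = UNet2DConditionModel.from_pretrained(
    "latent-consistency/lcm-ssd-1b", torch_dtype=torch.float16, variant="fp16"
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "segmind/SSD-1B", unet=unet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Latent-consistency models need only a handful of sampling steps.
image = pipe(
    "a red car standing on the side of the street",
    num_inference_steps=4,
    guidance_scale=8.0,
).images[0]
```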
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 4a2798b3edf4..7e3698d8ca16 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -20,15 +20,7 @@
import numpy as np
import torch
from PIL import Image
-from transformers import (
- CLIPImageProcessor,
- CLIPTextConfig,
- CLIPTextModel,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import (
AutoencoderKL,
@@ -36,18 +28,13 @@
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
- LCMScheduler,
StableDiffusionXLInpaintPipeline,
UNet2DConditionModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, torch_device
-from ..pipeline_params import (
- TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
- TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
- TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
-)
+from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
@@ -61,16 +48,8 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
image_latents_params = frozenset([])
- callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union(
- {
- "add_text_embeds",
- "add_time_ids",
- "mask",
- "masked_image_latents",
- }
- )
- def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
+ def get_dummy_components(self, skip_first_text_encoder=False):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
@@ -78,7 +57,6 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
sample_size=32,
in_channels=4,
out_channels=4,
- time_cond_proj_dim=time_cond_proj_dim,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
@@ -128,31 +106,6 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- torch.manual_seed(0)
- image_encoder_config = CLIPVisionConfig(
- hidden_size=32,
- image_size=224,
- projection_dim=32,
- intermediate_size=37,
- num_attention_heads=4,
- num_channels=3,
- num_hidden_layers=5,
- patch_size=14,
- )
-
- image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
-
- feature_extractor = CLIPImageProcessor(
- crop_size=224,
- do_center_crop=True,
- do_normalize=True,
- do_resize=True,
- image_mean=[0.48145466, 0.4578275, 0.40821073],
- image_std=[0.26862954, 0.26130258, 0.27577711],
- resample=3,
- size=224,
- )
-
components = {
"unet": unet,
"scheduler": scheduler,
@@ -161,8 +114,6 @@ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim
"tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
- "image_encoder": image_encoder,
- "feature_extractor": feature_extractor,
"requires_aesthetics_score": True,
}
return components
@@ -246,44 +197,6 @@ def test_stable_diffusion_xl_inpaint_euler(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- def test_stable_diffusion_xl_inpaint_euler_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLInpaintPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6611, 0.5569, 0.5531, 0.5471, 0.5918, 0.6393, 0.5074, 0.5468, 0.5185])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_xl_inpaint_euler_lcm_custom_timesteps(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components(time_cond_proj_dim=256)
- sd_pipe = StableDiffusionXLInpaintPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- del inputs["num_inference_steps"]
- inputs["timesteps"] = [999, 499]
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6611, 0.5569, 0.5531, 0.5471, 0.5918, 0.6393, 0.5074, 0.5468, 0.5185])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
@@ -381,66 +294,6 @@ def test_stable_diffusion_xl_refiner(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- def test_stable_diffusion_two_xl_mixture_of_denoiser_fast(self):
- components = self.get_dummy_components()
- pipe_1 = StableDiffusionXLInpaintPipeline(**components).to(torch_device)
- pipe_1.unet.set_default_attn_processor()
- pipe_2 = StableDiffusionXLInpaintPipeline(**components).to(torch_device)
- pipe_2.unet.set_default_attn_processor()
-
- def assert_run_mixture(
- num_steps, split, scheduler_cls_orig, num_train_timesteps=pipe_1.scheduler.config.num_train_timesteps
- ):
- inputs = self.get_dummy_inputs(torch_device)
- inputs["num_inference_steps"] = num_steps
-
- class scheduler_cls(scheduler_cls_orig):
- pass
-
- pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
- pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
-
- # Let's retrieve the number of timesteps we want to use
- pipe_1.scheduler.set_timesteps(num_steps)
- expected_steps = pipe_1.scheduler.timesteps.tolist()
-
- split_ts = num_train_timesteps - int(round(num_train_timesteps * split))
-
- if pipe_1.scheduler.order == 2:
- expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
- expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps))
- expected_steps = expected_steps_1 + expected_steps_2
- else:
- expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
- expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
-
- # now we monkey patch step `done_steps`
- # list into the step function for testing
- done_steps = []
- old_step = copy.copy(scheduler_cls.step)
-
- def new_step(self, *args, **kwargs):
- done_steps.append(args[1].cpu().item()) # args[1] is always the passed `t`
- return old_step(self, *args, **kwargs)
-
- scheduler_cls.step = new_step
-
- inputs_1 = {**inputs, **{"denoising_end": split, "output_type": "latent"}}
- latents = pipe_1(**inputs_1).images[0]
-
- assert expected_steps_1 == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
-
- inputs_2 = {**inputs, **{"denoising_start": split, "image": latents}}
- pipe_2(**inputs_2).images[0]
-
- assert expected_steps_2 == done_steps[len(expected_steps_1) :]
- assert expected_steps == done_steps, f"Failure with {scheduler_cls.__name__} and {num_steps} and {split}"
-
- for steps in [7, 20]:
- assert_run_mixture(steps, 0.33, EulerDiscreteScheduler)
- assert_run_mixture(steps, 0.33, HeunDiscreteScheduler)
-
- @slow
def test_stable_diffusion_two_xl_mixture_of_denoiser(self):
components = self.get_dummy_components()
pipe_1 = StableDiffusionXLInpaintPipeline(**components).to(torch_device)
@@ -465,14 +318,11 @@ class scheduler_cls(scheduler_cls_orig):
expected_steps = pipe_1.scheduler.timesteps.tolist()
split_ts = num_train_timesteps - int(round(num_train_timesteps * split))
+ expected_steps_1 = expected_steps[:split_ts]
+ expected_steps_2 = expected_steps[split_ts:]
- if pipe_1.scheduler.order == 2:
- expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
- expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps))
- expected_steps = expected_steps_1 + expected_steps_2
- else:
- expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
- expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
+ expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
+ expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
# now we monkey patch step `done_steps`
# list into the step function for testing
@@ -507,7 +357,6 @@ def new_step(self, *args, **kwargs):
]:
assert_run_mixture(steps, split, scheduler_cls)
- @slow
def test_stable_diffusion_three_xl_mixture_of_denoiser(self):
components = self.get_dummy_components()
pipe_1 = StableDiffusionXLInpaintPipeline(**components).to(torch_device)
@@ -540,18 +389,13 @@ class scheduler_cls(scheduler_cls_orig):
split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
+ expected_steps_1 = expected_steps[:split_1_ts]
+ expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
+ expected_steps_3 = expected_steps[split_2_ts:]
- if pipe_1.scheduler.order == 2:
- expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
- expected_steps_2 = expected_steps_1[-1:] + list(
- filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
- )
- expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
- expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
- else:
- expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
- expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
- expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
+ expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
+ expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
+ expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
# now we monkey patch step `done_steps`
# list into the step function for testing
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
index e20f8a0b54db..ca4017d11b79 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
@@ -36,23 +36,14 @@
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
-from ..test_pipelines_common import (
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
-)
+from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
enable_full_determinism()
class StableDiffusionXLInstructPix2PixPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- SDXLOptionalComponentsTesterMixin,
- unittest.TestCase,
+ PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionXLInstructPix2PixPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}
@@ -184,6 +175,3 @@ def test_latents_input(self):
def test_cfg(self):
pass
-
- def test_save_load_optional_components(self):
- self._test_save_load_optional_components()
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 32ae81ddc2d8..13861b581c9b 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -14,6 +14,7 @@
# limitations under the License.
import gc
+import glob
import json
import os
import random
@@ -56,7 +57,7 @@
UniPCMultistepScheduler,
logging,
)
-from diffusers.pipelines.pipeline_utils import _get_pipeline_class
+from diffusers.pipelines.pipeline_utils import _get_pipeline_class, variant_compatible_siblings
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import (
CONFIG_NAME,
@@ -792,54 +793,6 @@ def test_text_inversion_download(self):
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
- def test_text_inversion_multi_tokens(self):
- pipe1 = StableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
- )
- pipe1 = pipe1.to(torch_device)
-
- token1, token2 = "<*>", "<**>"
- ten1 = torch.ones((32,))
- ten2 = torch.ones((32,)) * 2
-
- num_tokens = len(pipe1.tokenizer)
-
- pipe1.load_textual_inversion(ten1, token=token1)
- pipe1.load_textual_inversion(ten2, token=token2)
- emb1 = pipe1.text_encoder.get_input_embeddings().weight
-
- pipe2 = StableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
- )
- pipe2 = pipe2.to(torch_device)
- pipe2.load_textual_inversion([ten1, ten2], token=[token1, token2])
- emb2 = pipe2.text_encoder.get_input_embeddings().weight
-
- pipe3 = StableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
- )
- pipe3 = pipe3.to(torch_device)
- pipe3.load_textual_inversion(torch.stack([ten1, ten2], dim=0), token=[token1, token2])
- emb3 = pipe3.text_encoder.get_input_embeddings().weight
-
- assert len(pipe1.tokenizer) == len(pipe2.tokenizer) == len(pipe3.tokenizer) == num_tokens + 2
- assert (
- pipe1.tokenizer.convert_tokens_to_ids(token1)
- == pipe2.tokenizer.convert_tokens_to_ids(token1)
- == pipe3.tokenizer.convert_tokens_to_ids(token1)
- == num_tokens
- )
- assert (
- pipe1.tokenizer.convert_tokens_to_ids(token2)
- == pipe2.tokenizer.convert_tokens_to_ids(token2)
- == pipe3.tokenizer.convert_tokens_to_ids(token2)
- == num_tokens + 1
- )
- assert emb1[num_tokens].sum().item() == emb2[num_tokens].sum().item() == emb3[num_tokens].sum().item()
- assert (
- emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item()
- )
-
def test_download_ignore_files(self):
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -909,58 +862,6 @@ def test_run_custom_pipeline(self):
# compare output to https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L102
assert output_str == "This is a test"
- def test_remote_components(self):
- # make sure that trust remote code has to be passed
- with self.assertRaises(ValueError):
- pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-components")
-
- # Check that only loading custom componets "my_unet", "my_scheduler" works
- pipeline = DiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-sdxl-custom-components", trust_remote_code=True
- )
-
- assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
- assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
- assert pipeline.__class__.__name__ == "StableDiffusionXLPipeline"
-
- pipeline = pipeline.to(torch_device)
- images = pipeline("test", num_inference_steps=2, output_type="np")[0]
-
- assert images.shape == (1, 64, 64, 3)
-
- # Check that only loading custom componets "my_unet", "my_scheduler" and explicit custom pipeline works
- pipeline = DiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-sdxl-custom-components", custom_pipeline="my_pipeline", trust_remote_code=True
- )
-
- assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
- assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
- assert pipeline.__class__.__name__ == "MyPipeline"
-
- pipeline = pipeline.to(torch_device)
- images = pipeline("test", num_inference_steps=2, output_type="np")[0]
-
- assert images.shape == (1, 64, 64, 3)
-
- def test_remote_auto_custom_pipe(self):
- # make sure that trust remote code has to be passed
- with self.assertRaises(ValueError):
- pipeline = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-custom-all")
-
- # Check that only loading custom componets "my_unet", "my_scheduler" and auto custom pipeline works
- pipeline = DiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-sdxl-custom-all", trust_remote_code=True
- )
-
- assert pipeline.config.unet == ("diffusers_modules.local.my_unet_model", "MyUNetModel")
- assert pipeline.config.scheduler == ("diffusers_modules.local.my_scheduler", "MyScheduler")
- assert pipeline.__class__.__name__ == "MyPipeline"
-
- pipeline = pipeline.to(torch_device)
- images = pipeline("test", num_inference_steps=2, output_type="np")[0]
-
- assert images.shape == (1, 64, 64, 3)
-
def test_local_custom_pipeline_repo(self):
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
pipeline = DiffusionPipeline.from_pretrained(
@@ -1184,8 +1085,8 @@ def test_stable_diffusion_components(self):
safety_checker=None,
feature_extractor=self.dummy_extractor,
).to(torch_device)
- img2img = StableDiffusionImg2ImgPipeline(**inpaint.components, image_encoder=None).to(torch_device)
- text2img = StableDiffusionPipeline(**inpaint.components, image_encoder=None).to(torch_device)
+ img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
+ text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
prompt = "A painting of a squirrel eating a burger"
@@ -1324,29 +1225,6 @@ def test_set_component_to_none(self):
assert out_image.shape == (1, 64, 64, 3)
assert np.abs(out_image - out_image_2).max() < 1e-3
- def test_optional_components_is_none(self):
- unet = self.dummy_cond_unet()
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- items = {
- "feature_extractor": self.dummy_extractor,
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": bert,
- "tokenizer": tokenizer,
- "safety_checker": None,
- # we don't add an image encoder
- }
-
- pipeline = StableDiffusionPipeline(**items)
-
- assert sorted(pipeline.components.keys()) == sorted(["image_encoder"] + list(items.keys()))
- assert pipeline.image_encoder is None
-
def test_set_scheduler_consistency(self):
unet = self.dummy_cond_unet()
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
@@ -1575,15 +1453,28 @@ def test_name_or_path(self):
assert sd.name_or_path == tmpdirname
- def test_error_no_variant_available(self):
+ def test_warning_no_variant_available(self):
variant = "fp16"
- with self.assertRaises(ValueError) as error_context:
- _ = StableDiffusionPipeline.download(
+ with self.assertWarns(FutureWarning) as warning_context:
+ cached_folder = StableDiffusionPipeline.download(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant
)
- assert "but no such modeling files are available" in str(error_context.exception)
- assert variant in str(error_context.exception)
+ assert "but no such modeling files are available" in str(warning_context.warning)
+ assert variant in str(warning_context.warning)
+
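+ # The download now falls back to the non-variant weights instead of raising, so we can inspect
+ # the cached folder and check that no fp16-variant files were actually fetched.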
+ def get_all_filenames(directory):
+ filenames = glob.glob(directory + "/**", recursive=True)
+ filenames = [f for f in filenames if os.path.isfile(f)]
+ return filenames
+
+ filenames = get_all_filenames(str(cached_folder))
+
+ all_model_files, variant_model_files = variant_compatible_siblings(filenames, variant=variant)
+
+ # make sure that none of the model names are variant model names
+ assert len(variant_model_files) == 0
+ assert len(all_model_files) > 0
def test_pipe_to(self):
unet = self.dummy_cond_unet()
diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py
index 1cd29565b8de..bfdedd25babe 100644
--- a/tests/pipelines/test_pipelines_auto.py
+++ b/tests/pipelines/test_pipelines_auto.py
@@ -156,54 +156,6 @@ def test_from_pipe_controlnet_new_task(self):
assert pipe_inpaint.__class__.__name__ == "StableDiffusionInpaintPipeline"
assert "controlnet" not in pipe_inpaint.components
- # testing `from_pipe` for text2img controlnet
- ## 1. from a different controlnet pipe, without controlnet argument
- pipe_control_text2img = AutoPipelineForText2Image.from_pipe(pipe_control_img2img)
- assert pipe_control_text2img.__class__.__name__ == "StableDiffusionControlNetPipeline"
- assert "controlnet" in pipe_control_text2img.components
-
- ## 2. from a different controlnet pipe, with controlnet argument
- pipe_control_text2img = AutoPipelineForText2Image.from_pipe(pipe_control_img2img, controlnet=controlnet)
- assert pipe_control_text2img.__class__.__name__ == "StableDiffusionControlNetPipeline"
- assert "controlnet" in pipe_control_text2img.components
-
- ## 3. from same controlnet pipeline class, with a different controlnet component
- pipe_control_text2img = AutoPipelineForText2Image.from_pipe(pipe_control_text2img, controlnet=controlnet)
- assert pipe_control_text2img.__class__.__name__ == "StableDiffusionControlNetPipeline"
- assert "controlnet" in pipe_control_text2img.components
-
- # testing from_pipe for inpainting
- ## 1. from a different controlnet pipeline class
- pipe_control_inpaint = AutoPipelineForInpainting.from_pipe(pipe_control_img2img)
- assert pipe_control_inpaint.__class__.__name__ == "StableDiffusionControlNetInpaintPipeline"
- assert "controlnet" in pipe_control_inpaint.components
-
- ## from a different controlnet pipe, with a different controlnet
- pipe_control_inpaint = AutoPipelineForInpainting.from_pipe(pipe_control_img2img, controlnet=controlnet)
- assert pipe_control_inpaint.__class__.__name__ == "StableDiffusionControlNetInpaintPipeline"
- assert "controlnet" in pipe_control_inpaint.components
-
- ## from same controlnet pipe, with a different controlnet
- pipe_control_inpaint = AutoPipelineForInpainting.from_pipe(pipe_control_inpaint, controlnet=controlnet)
- assert pipe_control_inpaint.__class__.__name__ == "StableDiffusionControlNetInpaintPipeline"
- assert "controlnet" in pipe_control_inpaint.components
-
- # testing from_pipe from img2img controlnet
- ## from a different controlnet pipe, without controlnet argument
- pipe_control_img2img = AutoPipelineForImage2Image.from_pipe(pipe_control_text2img)
- assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
- assert "controlnet" in pipe_control_img2img.components
-
- # from a different controlnet pipe, with a different controlnet component
- pipe_control_img2img = AutoPipelineForImage2Image.from_pipe(pipe_control_text2img, controlnet=controlnet)
- assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
- assert "controlnet" in pipe_control_img2img.components
-
- # from same controlnet pipeline class, with a different controlnet
- pipe_control_img2img = AutoPipelineForImage2Image.from_pipe(pipe_control_img2img, controlnet=controlnet)
- assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
- assert "controlnet" in pipe_control_img2img.components
-
@slow
class AutoPipelineIntegrationTest(unittest.TestCase):
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index e11175921184..6f2674a7b8f6 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -17,16 +17,7 @@
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
import diffusers
-from diffusers import (
- AsymmetricAutoencoderKL,
- AutoencoderKL,
- AutoencoderTiny,
- ConsistencyDecoderVAE,
- DDIMScheduler,
- DiffusionPipeline,
- StableDiffusionPipeline,
- UNet2DConditionModel,
-)
+from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
@@ -37,12 +28,6 @@
torch_device,
)
-from ..models.test_models_vae import (
- get_asym_autoencoder_kl_config,
- get_autoencoder_kl_config,
- get_autoencoder_tiny_config,
- get_consistency_vae_config,
-)
from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -186,34 +171,6 @@ def test_latents_input(self):
max_diff = np.abs(out - out_latents_inputs).max()
self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
- def test_multi_vae(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- block_out_channels = pipe.vae.config.block_out_channels
- norm_num_groups = pipe.vae.config.norm_num_groups
-
- vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
- configs = [
- get_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_consistency_vae_config(block_out_channels, norm_num_groups),
- get_autoencoder_tiny_config(block_out_channels),
- ]
-
- out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- for vae_cls, config in zip(vae_classes, configs):
- vae = vae_cls(**config)
- vae = vae.to(torch_device)
- components["vae"] = vae
- vae_pipe = self.pipeline_class(**components)
- out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- assert out_vae_np.shape == out_np.shape
-
@require_torch
class PipelineKarrasSchedulerTesterMixin:
@@ -274,6 +231,8 @@ class PipelineTesterMixin:
"latents",
"output_type",
"return_dict",
+ "callback",
+ "callback_steps",
]
)
@@ -335,20 +294,6 @@ def batch_params(self) -> frozenset:
"See existing pipeline tests for reference."
)
- @property
- def callback_cfg_params(self) -> frozenset:
- raise NotImplementedError(
- "You need to set the attribute `callback_cfg_params` in the child test class that requires to run test_callback_cfg. "
- "`callback_cfg_params` are the parameters that needs to be passed to the pipeline's callback "
- "function when dynamically adjusting `guidance_scale`. They are variables that require special"
- "treatment when `do_classifier_free_guidance` is `True`. `pipeline_params.py` provides some common"
- " sets of parameters such as `TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS`. If your pipeline's "
- "set of cfg arguments has minor changes from one of the common sets of cfg arguments, "
- "do not make modifications to the existing common sets of cfg arguments. I.e. for inpaint pipeine, you "
- " need to adjust batch size of `mask` and `masked_image_latents` so should set the attribute as"
- "`callback_cfg_params = TEXT_TO_IMAGE_CFG_PARAMS.union({'mask', 'masked_image_latents'})`"
- )
-
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors
super().tearDown()
@@ -536,7 +481,7 @@ def _test_inference_batch_single_identical(
assert output_batch[0].shape[0] == batch_size
- max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
+ max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
assert max_diff < expected_max_diff
def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
@@ -745,7 +690,7 @@ def _test_attention_slicing_forward_pass(
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
if test_mean_pixel_difference:
- assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
+ assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
@@ -797,15 +742,6 @@ def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
- offloaded_modules = [
- v
- for k, v in pipe.components.items()
- if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
- ]
- (
- self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
- f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
- )
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
@@ -917,107 +853,6 @@ def test_cfg(self):
assert out_cfg.shape == out_no_cfg.shape
- def test_callback_inputs(self):
- sig = inspect.signature(self.pipeline_class.__call__)
- has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
- has_callback_step_end = "callback_on_step_end" in sig.parameters
-
- if not (has_callback_tensor_inputs and has_callback_step_end):
- return
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- self.assertTrue(
- hasattr(pipe, "_callback_tensor_inputs"),
- f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
- )
-
- def callback_inputs_subset(pipe, i, t, callback_kwargs):
- # interate over callback args
- for tensor_name, tensor_value in callback_kwargs.items():
- # check that we're only passing in allowed tensor inputs
- assert tensor_name in pipe._callback_tensor_inputs
-
- return callback_kwargs
-
- def callback_inputs_all(pipe, i, t, callback_kwargs):
- for tensor_name in pipe._callback_tensor_inputs:
- assert tensor_name in callback_kwargs
-
- # interate over callback args
- for tensor_name, tensor_value in callback_kwargs.items():
- # check that we're only passing in allowed tensor inputs
- assert tensor_name in pipe._callback_tensor_inputs
-
- return callback_kwargs
-
- inputs = self.get_dummy_inputs(torch_device)
-
- # Test passing in a subset
- inputs["callback_on_step_end"] = callback_inputs_subset
- inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
- inputs["output_type"] = "latent"
- output = pipe(**inputs)[0]
-
- # Test passing in a everything
- inputs["callback_on_step_end"] = callback_inputs_all
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- inputs["output_type"] = "latent"
- output = pipe(**inputs)[0]
-
- def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
- is_last = i == (pipe.num_timesteps - 1)
- if is_last:
- callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
- return callback_kwargs
-
- inputs["callback_on_step_end"] = callback_inputs_change_tensor
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- inputs["output_type"] = "latent"
- output = pipe(**inputs)[0]
- assert output.abs().sum() == 0
-
- def test_callback_cfg(self):
- sig = inspect.signature(self.pipeline_class.__call__)
- has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
- has_callback_step_end = "callback_on_step_end" in sig.parameters
-
- if not (has_callback_tensor_inputs and has_callback_step_end):
- return
-
- if "guidance_scale" not in sig.parameters:
- return
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- self.assertTrue(
- hasattr(pipe, "_callback_tensor_inputs"),
- f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
- )
-
- def callback_increase_guidance(pipe, i, t, callback_kwargs):
- pipe._guidance_scale += 1.0
-
- return callback_kwargs
-
- inputs = self.get_dummy_inputs(torch_device)
-
- # use cfg guidance because some pipelines modify the shape of the latents
- # outside of the denoising loop
- inputs["guidance_scale"] = 2.0
- inputs["callback_on_step_end"] = callback_increase_guidance
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- _ = pipe(**inputs)[0]
-
- # we increase the guidance scale by 1.0 at every step
- # check that the guidance scale is increased by the number of scheduler timesteps
- # accounts for models that modify the number of inference steps based on strength
- assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps)
-
@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):
@@ -1139,150 +974,6 @@ def test_push_to_hub_in_organization(self):
delete_repo(self.org_repo_id, token=TOKEN)
-# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders
-# and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()`
-# test for all such pipelines. This requires us to use a custom `encode_prompt()` function.
-class SDXLOptionalComponentsTesterMixin:
- def encode_prompt(
- self, tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None
- ):
- device = text_encoders[0].device
-
- if isinstance(prompt, str):
- prompt = [prompt]
- batch_size = len(prompt)
-
- prompt_embeds_list = []
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
- text_inputs = tokenizer(
- prompt,
- padding="max_length",
- max_length=tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
-
- text_input_ids = text_inputs.input_ids
-
- prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
- pooled_prompt_embeds = prompt_embeds[0]
- prompt_embeds = prompt_embeds.hidden_states[-2]
- prompt_embeds_list.append(prompt_embeds)
-
- prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
-
- if negative_prompt is None:
- negative_prompt_embeds = torch.zeros_like(prompt_embeds)
- negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
- else:
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-
- negative_prompt_embeds_list = []
- for tokenizer, text_encoder in zip(tokenizers, text_encoders):
- uncond_input = tokenizer(
- negative_prompt,
- padding="max_length",
- max_length=tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
-
- negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True)
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
- negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
- negative_prompt_embeds_list.append(negative_prompt_embeds)
-
- negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
-
- bs_embed, seq_len, _ = prompt_embeds.shape
-
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
- # for classifier-free guidance
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
- seq_len = negative_prompt_embeds.shape[1]
-
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
- bs_embed * num_images_per_prompt, -1
- )
-
- # for classifier-free guidance
- negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
- bs_embed * num_images_per_prompt, -1
- )
-
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
- def _test_save_load_optional_components(self, expected_max_difference=1e-4):
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- for optional_component in pipe._optional_components:
- setattr(pipe, optional_component, None)
-
- for component in pipe.components.values():
- if hasattr(component, "set_default_attn_processor"):
- component.set_default_attn_processor()
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator_device = "cpu"
- inputs = self.get_dummy_inputs(generator_device)
-
- tokenizer = components.pop("tokenizer")
- tokenizer_2 = components.pop("tokenizer_2")
- text_encoder = components.pop("text_encoder")
- text_encoder_2 = components.pop("text_encoder_2")
-
- tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2]
- text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2]
- prompt = inputs.pop("prompt")
- (
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ) = self.encode_prompt(tokenizers, text_encoders, prompt)
- inputs["prompt_embeds"] = prompt_embeds
- inputs["negative_prompt_embeds"] = negative_prompt_embeds
- inputs["pooled_prompt_embeds"] = pooled_prompt_embeds
- inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
-
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
- for component in pipe_loaded.components.values():
- if hasattr(component, "set_default_attn_processor"):
- component.set_default_attn_processor()
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for optional_component in pipe._optional_components:
- self.assertTrue(
- getattr(pipe_loaded, optional_component) is None,
- f"`{optional_component}` did not stay set to None after loading.",
- )
-
- inputs = self.get_dummy_inputs(generator_device)
- _ = inputs.pop("prompt")
- inputs["prompt_embeds"] = prompt_embeds
- inputs["negative_prompt_embeds"] = negative_prompt_embeds
- inputs["pooled_prompt_embeds"] = pooled_prompt_embeds
- inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds
-
- output_loaded = pipe_loaded(**inputs)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
- self.assertLess(max_diff, expected_max_difference)
-
-
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
index e9f435239c92..933583ce4b70 100644
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
+++ b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
@@ -62,8 +62,8 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet3DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=1,
+ block_out_channels=(32, 32),
+ layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
@@ -71,7 +71,6 @@ def get_dummy_components(self):
up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
cross_attention_dim=4,
attention_head_dim=4,
- norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
@@ -82,14 +81,13 @@ def get_dummy_components(self):
)
torch.manual_seed(0)
vae = AutoencoderKL(
- block_out_channels=(8,),
+ block_out_channels=(32,),
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D"],
latent_channels=4,
sample_size=32,
- norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
@@ -144,11 +142,10 @@ def test_text_to_video_default_case(self):
image_slice = frames[0][-3:, -3:, -1]
assert frames[0].shape == (32, 32, 3)
- expected_slice = np.array([192.0, 44.0, 157.0, 140.0, 108.0, 104.0, 123.0, 144.0, 129.0])
+ expected_slice = np.array([91.0, 152.0, 66.0, 192.0, 94.0, 126.0, 101.0, 123.0, 152.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- @unittest.skipIf(torch_device != "cuda", reason="Feature isn't heavily used. Test in CUDA environment only.")
def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False, expected_max_diff=3e-3)
diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
index 1785eb967f16..b5fe3451774b 100644
--- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
+++ b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
@@ -70,16 +70,15 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet3DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=1,
+ block_out_channels=(32, 64, 64, 64),
+ layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
+ down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"),
+ up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
cross_attention_dim=32,
attention_head_dim=4,
- norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
@@ -90,18 +89,13 @@ def get_dummy_components(self):
)
torch.manual_seed(0)
vae = AutoencoderKL(
- block_out_channels=[
- 8,
- ],
+ block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
- down_block_types=[
- "DownEncoderBlock2D",
- ],
- up_block_types=["UpDecoderBlock2D"],
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
+ sample_size=128,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
@@ -160,7 +154,7 @@ def test_text_to_video_default_case(self):
image_slice = frames[0][-3:, -3:, -1]
assert frames[0].shape == (32, 32, 3)
- expected_slice = np.array([162.0, 136.0, 132.0, 140.0, 139.0, 137.0, 169.0, 134.0, 132.0])
+ expected_slice = np.array([106, 117, 113, 174, 137, 112, 148, 151, 131])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
index bcc2237c92d6..b567f507d1d2 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
@@ -232,9 +232,3 @@ def test_inference_batch_single_identical(self):
@unittest.skip(reason="flakey and float16 requires CUDA")
def test_float16_inference(self):
super().test_float16_inference()
-
- def test_callback_inputs(self):
- pass
-
- def test_callback_cfg(self):
- pass
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
index 029680b677f0..1442196251d6 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
@@ -43,7 +43,6 @@ class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
"return_dict",
]
test_xformers_attention = False
- callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"]
@property
def text_embedder_hidden_size(self):
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
index 5e1b89c0d2e0..a85ec0e2c102 100644
--- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
@@ -17,24 +17,11 @@
import numpy as np
import torch
-import torch.nn as nn
-import torch.nn.functional as F
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import (
- LoRAAttnProcessor,
- LoRAAttnProcessor2_0,
-)
from diffusers.pipelines.wuerstchen import WuerstchenPrior
-from diffusers.utils.import_utils import is_peft_available
-from diffusers.utils.testing_utils import enable_full_determinism, require_peft_backend, skip_mps, torch_device
-
-
-if is_peft_available():
- from peft import LoraConfig
- from peft.tuners.tuners_utils import BaseTunerLayer
+from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
from ..test_pipelines_common import PipelineTesterMixin
@@ -42,19 +29,6 @@
enable_full_determinism()
-def create_prior_lora_layers(unet: nn.Module):
- lora_attn_procs = {}
- for name in unet.attn_processors.keys():
- lora_attn_processor_class = (
- LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
- )
- lora_attn_procs[name] = lora_attn_processor_class(
- hidden_size=unet.config.c,
- )
- unet_lora_layers = AttnProcsLayers(lora_attn_procs)
- return lora_attn_procs, unet_lora_layers
-
-
class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WuerstchenPriorPipeline
params = ["prompt"]
@@ -70,7 +44,6 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"return_dict",
]
test_xformers_attention = False
- callback_cfg_params = ["text_encoder_hidden_states"]
@property
def text_embedder_hidden_size(self):
@@ -210,87 +183,3 @@ def test_attention_slicing_forward_pass(self):
@unittest.skip(reason="flaky for now")
def test_float16_inference(self):
super().test_float16_inference()
-
- # override because we need to make sure latent_mean and latent_std to be 0
- def test_callback_inputs(self):
- components = self.get_dummy_components()
- components["latent_mean"] = 0
- components["latent_std"] = 0
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- self.assertTrue(
- hasattr(pipe, "_callback_tensor_inputs"),
- f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
- )
-
- def callback_inputs_test(pipe, i, t, callback_kwargs):
- missing_callback_inputs = set()
- for v in pipe._callback_tensor_inputs:
- if v not in callback_kwargs:
- missing_callback_inputs.add(v)
- self.assertTrue(
- len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
- )
- last_i = pipe.num_timesteps - 1
- if i == last_i:
- callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
- return callback_kwargs
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["callback_on_step_end"] = callback_inputs_test
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- inputs["output_type"] = "latent"
-
- output = pipe(**inputs)[0]
- assert output.abs().sum() == 0
-
- def check_if_lora_correctly_set(self, model) -> bool:
- """
- Checks if the LoRA layers are correctly set with peft
- """
- for module in model.modules():
- if isinstance(module, BaseTunerLayer):
- return True
- return False
-
- def get_lora_components(self):
- prior = self.dummy_prior
-
- prior_lora_config = LoraConfig(
- r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
- )
-
- prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
-
- lora_components = {
- "prior_lora_layers": prior_lora_layers,
- "prior_lora_attn_procs": prior_lora_attn_procs,
- }
-
- return prior, prior_lora_config, lora_components
-
- @require_peft_backend
- def test_inference_with_prior_lora(self):
- _, prior_lora_config, _ = self.get_lora_components()
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output_no_lora = pipe(**self.get_dummy_inputs(device))
- image_embed = output_no_lora.image_embeddings
- self.assertTrue(image_embed.shape == (1, 2, 24, 24))
-
- pipe.prior.add_adapter(prior_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")
-
- output_lora = pipe(**self.get_dummy_inputs(device))
- lora_image_embed = output_lora.image_embeddings
-
- self.assertTrue(image_embed.shape == lora_image_embed.shape)
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 7fe71941b4e7..6e6442e0daf6 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -29,7 +29,6 @@ def get_scheduler_config(self, **kwargs):
"algorithm_type": "dpmsolver++",
"solver_type": "midpoint",
"lower_order_final": False,
- "euler_at_final": False,
"lambda_min_clipped": -float("inf"),
"variance_type": None,
}
@@ -196,10 +195,6 @@ def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
self.check_over_configs(lower_order_final=False)
- def test_euler_at_final(self):
- self.check_over_configs(euler_at_final=True)
- self.check_over_configs(euler_at_final=False)
-
def test_lambda_min_clipped(self):
self.check_over_configs(lambda_min_clipped=-float("inf"))
self.check_over_configs(lambda_min_clipped=-5.1)
@@ -263,12 +258,6 @@ def test_full_loop_with_karras_and_v_prediction(self):
assert abs(result_mean.item() - 0.2096) < 1e-3
- def test_full_loop_with_lu_and_v_prediction(self):
- sample = self.full_loop(prediction_type="v_prediction", use_lu_lambdas=True)
- result_mean = torch.mean(torch.abs(sample))
-
- assert abs(result_mean.item() - 0.1554) < 1e-3
-
def test_switch(self):
# make sure that iterating over schedulers with same config names gives same results
# for defaults
diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py
index 3249d7032bad..fa885a0542eb 100644
--- a/tests/schedulers/test_scheduler_euler.py
+++ b/tests/schedulers/test_scheduler_euler.py
@@ -37,14 +37,6 @@ def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
- def test_timestep_type(self):
- timestep_types = ["discrete", "continuous"]
- for timestep_type in timestep_types:
- self.check_over_configs(timestep_type=timestep_type)
-
- def test_karras_sigmas(self):
- self.check_over_configs(use_karras_sigmas=True, sigma_min=0.02, sigma_max=700.0)
-
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index 08c5ad5c3a50..8bc95b38cf34 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -352,8 +352,8 @@ def check_over_configs(self, time_step=0, **config):
_ = scheduler.scale_model_input(sample, scaled_sigma_max)
_ = new_scheduler.scale_model_input(sample, scaled_sigma_max)
elif scheduler_class != VQDiffusionScheduler:
- _ = scheduler.scale_model_input(sample, scheduler.timesteps[-1])
- _ = new_scheduler.scale_model_input(sample, scheduler.timesteps[-1])
+ _ = scheduler.scale_model_input(sample, 0)
+ _ = new_scheduler.scale_model_input(sample, 0)
# Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
diff --git a/train_text_to_image.py b/train_text_to_image.py
new file mode 100644
index 000000000000..4d438b397973
--- /dev/null
+++ b/train_text_to_image.py
@@ -0,0 +1,988 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
+from diffusers.utils.import_utils import is_xformers_available
+from llmga.diffusers.my_datasets.dataset_text2img import Text2ImgTrainDataset
+from llmga.diffusers.my_utils.util import get_unweighted_text_embeddings, get_text_index
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.22.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
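+# Maps known Hub dataset names to their (image column, caption column) pair, as in the upstream
+# diffusers text-to-image example; it lets the dataset columns be inferred when --image_column /
+# --caption_column are not passed explicitly.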
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if images is not None and len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {args.pretrained_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+ """
+ model_card = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
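+# Assembles a full StableDiffusionPipeline from the in-training modules (unwrapped from their
+# accelerate wrappers), generates one image per --validation_prompts entry, logs the results to the
+# active tracker (tensorboard or wandb), then frees the pipeline to release GPU memory.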
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ returns either a context list that includes one that will disable zero.Init or an empty context list
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+ # frozen models from being partitioned during `zero.Init` which gets called during
+ # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+ # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # Freeze vae and text_encoder and set unet to trainable
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.train()
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
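+ # Build the training dataset from the local data directory; the custom dataset class is assumed to return dicts with "pixel_values" and "caption" keys, as consumed by collate_fn below.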
+ train_dataset = Text2ImgTrainDataset(args.train_data_dir, args)
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ prompts = [example["caption"] for example in examples]
+ input_ids = get_text_index(tokenizer, prompts)
+
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # For mixed precision training we cast all non-trainable weights (vae and text_encoder) to half-precision,
+ # as these weights are only used for inference and keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ args.mixed_precision = accelerator.mixed_precision
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ args.mixed_precision = accelerator.mixed_precision
+
+ # Move text_encoder and vae to gpu and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
+ if args.input_perturbation:
+ new_noise = noise + args.input_perturbation * torch.randn_like(noise)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ if args.input_perturbation:
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+ else:
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
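+ # get_unweighted_text_embeddings is assumed to chunk prompts longer than the tokenizer's 77-token limit and concatenate the per-chunk CLIP embeddings, unlike the single-call variant kept commented out below.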
+ encoder_hidden_states = get_unweighted_text_embeddings(text_encoder, batch["input_ids"], chunk_length=tokenizer.model_max_length)
+ # encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/train_text_to_image_inpaint.py b/train_text_to_image_inpaint.py
new file mode 100644
index 000000000000..0b903eb7cd66
--- /dev/null
+++ b/train_text_to_image_inpaint.py
@@ -0,0 +1,1018 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
+from diffusers.utils.import_utils import is_xformers_available
+from my_datasets.dataset_text2img import Text2ImgTrainDataset
+from my_datasets.dataset_inpainting import InpaintingTextTrainDataset
+from my_utils.util import get_unweighted_text_embeddings, get_text_index
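+# NOTE: my_datasets and my_utils are local helper modules expected to live alongside these scripts (custom datasets and long-prompt text-embedding utilities).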
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.22.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+# def prepare_mask_latents(
+# self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+# ):
+# # resize the mask to latents shape as we concatenate the mask to the latents
+# # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+# # and half precision
+# mask = torch.nn.functional.interpolate(
+# mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+# )
+# mask = mask.to(device=device, dtype=dtype)
+
+# masked_image = masked_image.to(device=device, dtype=dtype)
+
+# # aligning device to prevent device errors when concating it with the latent model input
+# masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+# return mask, masked_image_latents
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images=None,
+ repo_folder=None,
+):
+ img_str = ""
+ if len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {args.pretrained_model_name_or_path}
+datasets:
+- {args.dataset_name}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+ """
+ model_card = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}\n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_card += wandb_info
+
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
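+ # NOTE: this builds a plain text-to-image pipeline; with a 9-channel inpainting UNet, validation would likely need StableDiffusionInpaintPipeline plus example images and masks.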
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ safety_checker=None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+ parser.add_argument("--blank_mask_prob", type=float, default=0.0, help="Probability of using a blank mask for a training sample when training the inpainting UNet.")
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ returns either a context list that includes one that will disable zero.Init or an empty context list
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+ # frozen models from being partitioned during `zero.Init` which gets called during
+ # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+ # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # Freeze vae and text_encoder and set unet to trainable
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.train()
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
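+ # Build the inpainting training dataset; it is assumed to return "pixel_values", "masked_image", "mask" and "caption" per sample, matching collate_fn below.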
+ train_dataset = InpaintingTextTrainDataset(args.train_data_dir, args)
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
+
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ masked_image = torch.stack([example["masked_image"] for example in examples])
+ mask = torch.stack([example["mask"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ masked_image = masked_image.to(memory_format=torch.contiguous_format).float()
+ mask = mask.to(memory_format=torch.contiguous_format).float()
+ prompts = [example["caption"] for example in examples]
+ input_ids = get_text_index(tokenizer, prompts)
+ return {"pixel_values": pixel_values, "input_ids": input_ids, "masked_image": masked_image, "mask": mask}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # For mixed precision training we cast all non-trainable weights (vae and text_encoder) to half-precision,
+ # as these weights are only used for inference and keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ args.mixed_precision = accelerator.mixed_precision
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ args.mixed_precision = accelerator.mixed_precision
+
+ # Move text_encoder and vae to gpu and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
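+ # Encode the masked image into latent space and downsample the mask to the latent resolution so both can be concatenated with the noisy latents.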
+ masked_latents = vae.encode(batch["masked_image"].to(weight_dtype)).latent_dist.sample()
+ masked_latents = masked_latents * vae.config.scaling_factor
+ mask = torch.nn.functional.interpolate(batch["mask"], size=masked_latents.shape[-2:])
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
+ if args.input_perturbation:
+ new_noise = noise + args.input_perturbation * torch.randn_like(noise)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ if args.input_perturbation:
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+ else:
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
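+ # Inpainting conditioning: concatenate noisy latents (4 ch), the downsampled mask (1 ch) and masked-image latents (4 ch); this assumes the base checkpoint is an inpainting UNet with 9 input channels.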
+ noisy_latents = torch.cat([noisy_latents, mask, masked_latents], 1)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = get_unweighted_text_embeddings(text_encoder, batch["input_ids"], chunk_length=tokenizer.model_max_length)
+ # encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/train_text_to_image_sdxl.py b/train_text_to_image_sdxl.py
new file mode 100644
index 000000000000..63b75af177f4
--- /dev/null
+++ b/train_text_to_image_sdxl.py
@@ -0,0 +1,1163 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image."""
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+from my_utils.util import get_unweighted_text_embeddings_SDXL1, get_unweighted_text_embeddings_SDXL2, get_text_index
+from my_datasets.dataset_text2img_sdxl import Text2ImgTrainDataset
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.22.0.dev0")
+
+logger = get_logger(__name__)
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ validation_prompt=None,
+ base_model=str,
+ dataset_name=str,
+ repo_folder=None,
+ vae_path=None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+dataset: {dataset_name}
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+ """
+ model_card = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}\n
+{img_str}
+
+Special VAE used for training: {vae_path}.
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sdxl-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--timestep_bias_strategy",
+ type=str,
+ default="none",
+ choices=["earlier", "later", "range", "none"],
+ help=(
+ "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
+ " Choices: ['earlier', 'later', 'range', 'none']."
+ " The default is 'none', which means no bias is applied, and training proceeds normally."
+ " The value of 'later' will increase the frequency of the model's final training timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_multiplier",
+ type=float,
+ default=1.0,
+ help=(
+ "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
+ " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_begin",
+ type=int,
+ default=0,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
+ " Defaults to zero, which equates to having no specific bias."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_end",
+ type=int,
+ default=1000,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
+ " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_portion",
+ type=float,
+ default=0.25,
+ help=(
+ "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased."
+ " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
+ " whether the biased portions are in the earlier or later timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(batch, device, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
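+    # Tokenize the batch captions with both SDXL tokenizers and encode them with both text encoders,
+    # dropping a caption to "" with probability `proportion_empty_prompts`. The per-token embeddings of
+    # the two encoders are concatenated along the last dimension, and the pooled embedding of the second
+    # (projection) text encoder is kept, following the SDXL conditioning convention.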
+ prompt_embeds_list = []
+ prompt_batch = batch["prompts"]
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+        fg = 0
+        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+            text_input_ids = get_text_index(tokenizer, captions)
+
+            if fg == 0:
+                pooled_prompt_embeds, prompt_embeds = get_unweighted_text_embeddings_SDXL1(text_encoder, text_input_ids.to(device), chunk_length=tokenizer.model_max_length)
+                fg = 1
+            else:
+                pooled_prompt_embeds, prompt_embeds = get_unweighted_text_embeddings_SDXL2(text_encoder, text_input_ids.to(device), chunk_length=tokenizer.model_max_length)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ # pooled_prompt_embeds = prompt_embeds[0]
+ # prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+        # The two embedding sequences must have the same length before concatenation; if the first
+        # encoder produced a longer sequence, truncate it to match the second one.
+        seq_len0 = prompt_embeds_list[0].shape[1]
+        seq_len1 = prompt_embeds_list[1].shape[1]
+        if seq_len0 > seq_len1:
+            prompt_embeds_list[0] = prompt_embeds_list[0][:, :seq_len1, :]
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ batch.update({"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds})
+ return batch
+
+
+def compute_vae_encodings(batch, device, vae):
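+    # Encode the batch pixel values into VAE latents and scale them by the VAE's configured scaling factor.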
+ pixel_values = batch["pixel_values"]
+
+ with torch.no_grad():
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ batch.update({"model_input": model_input})
+ return batch
+
+
+def generate_timestep_weights(args, num_timesteps):
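+    # Build a sampling distribution over timesteps: start uniform, scale the selected slice by the bias
+    # multiplier, then renormalize so the weights can be fed to torch.multinomial in the training loop.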
+ weights = torch.ones(num_timesteps)
+
+ # Determine the indices to bias
+ num_to_bias = int(args.timestep_bias_portion * num_timesteps)
+
+ if args.timestep_bias_strategy == "later":
+ bias_indices = slice(-num_to_bias, None)
+ elif args.timestep_bias_strategy == "earlier":
+ bias_indices = slice(0, num_to_bias)
+ elif args.timestep_bias_strategy == "range":
+ # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500.
+ range_begin = args.timestep_bias_begin
+ range_end = args.timestep_bias_end
+ if range_begin < 0:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
+ )
+ if range_end > num_timesteps:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
+ )
+ bias_indices = slice(range_begin, range_end)
+ else: # 'none' or any other string
+ return weights
+ if args.timestep_bias_multiplier <= 0:
+        raise ValueError(
+ "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
+ " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
+ " A timestep bias multiplier less than or equal to 0 is not allowed."
+ )
+
+ # Apply the bias
+ weights[bias_indices] *= args.timestep_bias_multiplier
+
+ # Normalize
+ weights /= weights.sum()
+
+ return weights
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ # Check for terminal SNR in combination with SNR Gamma
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # Freeze vae and text encoders.
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ # Set unet as trainable.
+ unet.train()
+
+ # For mixed precision training we cast all non-trainable weigths to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = unet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+    train_dataset = Text2ImgTrainDataset(args.train_data_dir, args)
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
+
+    # Prompt embeddings and VAE encodings are computed on the fly for each batch inside the training
+    # loop below, so the text encoders and the VAE stay loaded on the accelerator device.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ compute_embeddings_fn = functools.partial(
+ encode_prompt,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ )
+ compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+
+ def collate_fn(examples):
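+        # Keep the original sizes and crop coordinates needed for SDXL micro-conditioning, stack the pixel
+        # tensors, and pass the raw captions through so they can be encoded per batch in the training loop.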
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ prompts=[example["caption"] for example in examples]
+
+ return {
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ "pixel_values": pixel_values,
+ "prompts": prompts
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+                # Compute the prompt embeddings and VAE latents for this batch.
+                batch = compute_embeddings_fn(batch, accelerator.device)
+                batch = compute_vae_encodings_fn(batch, accelerator.device)
+                model_input = batch["model_input"].to(accelerator.device)
+
+                # Sample noise that we'll add to the latents
+                noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+ if args.timestep_bias_strategy == "none":
+ # Sample a random timestep for each image without bias.
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ else:
+ # Sample a random timestep for each image, potentially biased by the timestep weights.
+ # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
+ weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
+ model_input.device
+ )
+ timesteps = torch.multinomial(weights, bsz, replacement=True).long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+                # SDXL micro-conditioning "time ids": original size, crop top-left coordinates and target size.
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ elif noise_scheduler.config.prediction_type == "sample":
+ # We set the target to latents here, but the model_pred will return the noise sample prediction.
+ target = model_input
+ # We will have to subtract the noise residual from the prediction to get the target sample.
+ model_pred = model_pred - noise
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = unet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ # create pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ # Serialize pipeline.
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline.save_pretrained(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline = pipeline.to(accelerator.device)
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id=repo_id,
+ images=images,
+ validation_prompt=args.validation_prompt,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/train_text_to_image_sdxl_inpainting.py b/train_text_to_image_sdxl_inpainting.py
new file mode 100644
index 000000000000..675545cc7f97
--- /dev/null
+++ b/train_text_to_image_sdxl_inpainting.py
@@ -0,0 +1,1176 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image."""
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+from my_utils.util import get_unweighted_text_embeddings_SDXL1, get_unweighted_text_embeddings_SDXL2, get_text_index
+from my_datasets.dataset_inpainting_sdxl import InpaintingTextTrainDataset
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.22.0.dev0")
+
+logger = get_logger(__name__)
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ validation_prompt=None,
+ base_model=str,
+ dataset_name=str,
+ repo_folder=None,
+ vae_path=None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+dataset: {dataset_name}
+tags:
+- stable-diffusion-xl
+- stable-diffusion-xl-diffusers
+- text-to-image
+- diffusers
+inference: true
+---
+ """
+ model_card = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}\n
+{img_str}
+
+Special VAE used for training: {vae_path}.
+"""
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
+ f.write(yaml + model_card)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sdxl-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--timestep_bias_strategy",
+ type=str,
+ default="none",
+ choices=["earlier", "later", "range", "none"],
+ help=(
+ "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
+ " Choices: ['earlier', 'later', 'range', 'none']."
+ " The default is 'none', which means no bias is applied, and training proceeds normally."
+ " The value of 'later' will increase the frequency of the model's final training timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_multiplier",
+ type=float,
+ default=1.0,
+ help=(
+ "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
+ " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_begin",
+ type=int,
+ default=0,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
+ " Defaults to zero, which equates to having no specific bias."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_end",
+ type=int,
+ default=1000,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
+ " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_portion",
+ type=float,
+ default=0.25,
+ help=(
+ "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased."
+ " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
+ " whether the biased portions are in the earlier or later timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(batch, device, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
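+ # SDXL conditions on two text encoders: the per-token embeddings from both are concatenated
+ # along the feature dimension, while the pooled embedding comes from the second text encoder.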
+ prompt_embeds_list = []
+ prompt_batch = batch["prompts"]
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ fg = 0
+
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_input_ids = get_text_index(tokenizer, captions)
+ if fg == 0:
+ pooled_prompt_embeds, prompt_embeds = get_unweighted_text_embeddings_SDXL1(text_encoder, text_input_ids.to(device), chunk_length=tokenizer.model_max_length)
+ fg = 1
+ else:
+ pooled_prompt_embeds, prompt_embeds = get_unweighted_text_embeddings_SDXL2(text_encoder, text_input_ids.to(device), chunk_length=tokenizer.model_max_length)
+
+ # We are only interested in the pooled output of the final (second) text encoder
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+
+ # The two encoders may return embeddings with different sequence lengths; truncate the first to match the second before concatenating.
+ seq_len0 = prompt_embeds_list[0].shape[1]
+ seq_len1 = prompt_embeds_list[1].shape[1]
+ if seq_len0 > seq_len1:
+ prompt_embeds_list[0] = prompt_embeds_list[0][:, :seq_len1, :]
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ # print("++++++++++++++",pooled_prompt_embeds.shape)
+ batch.update({"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds})
+ return batch
+
+
+def compute_vae_encodings(batch, device, in_name, out_name, vae):
+ pixel_values = batch[in_name]
+
+ with torch.no_grad():
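+ # Encode the images into latents and scale them by the VAE's scaling factor so they match
+ # the latent distribution the diffusion model is trained on.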
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ batch.update({out_name: model_input})
+ return batch
+
+
+def generate_timestep_weights(args, num_timesteps):
+ weights = torch.ones(num_timesteps)
+
+ # Determine the indices to bias
+ num_to_bias = int(args.timestep_bias_portion * num_timesteps)
+
+ if args.timestep_bias_strategy == "later":
+ bias_indices = slice(-num_to_bias, None)
+ elif args.timestep_bias_strategy == "earlier":
+ bias_indices = slice(0, num_to_bias)
+ elif args.timestep_bias_strategy == "range":
+ # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500.
+ range_begin = args.timestep_bias_begin
+ range_end = args.timestep_bias_end
+ if range_begin < 0:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
+ )
+ if range_end > num_timesteps:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
+ )
+ bias_indices = slice(range_begin, range_end)
+ else: # 'none' or any other string
+ return weights
+ if args.timestep_bias_multiplier <= 0:
+ raise ValueError(
+ "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
+ " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
+ " A timestep bias multiplier less than or equal to 0 is not allowed."
+ )
+
+ # Apply the bias
+ weights[bias_indices] *= args.timestep_bias_multiplier
+
+ # Normalize
+ weights /= weights.sum()
+
+ return weights
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ # Check for terminal SNR in combination with SNR Gamma
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
+ )
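+ # If a separate VAE checkpoint is provided (e.g. an fp16-fix SDXL VAE), load it from the repo
+ # root; otherwise use the "vae" subfolder of the base model.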
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+
+ # Freeze vae and text encoders.
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ # Set unet as trainable.
+ unet.train()
+
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Create EMA for the unet.
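+ # The EMA model keeps an exponential moving average of the UNet weights; it is used during
+ # validation and for the final export, which typically yields smoother results than the raw weights.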
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
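+ # With --scale_lr the learning rate is scaled linearly with the effective batch size
+ # (per-device batch size x gradient accumulation steps x number of processes).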
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = unet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+ train_dataset = InpaintingTextTrainDataset(args.train_data_dir, args)
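+ # The custom dataset is expected to yield, per example: "pixel_values", "masked_image", "mask",
+ # "caption", "original_sizes" and "crop_top_lefts" (see collate_fn below).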
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+
+ # Build helpers to compute the text embeddings and VAE encodings. Unlike the original SDXL
+ # text-to-image script, they are applied on the fly to every batch inside the training loop,
+ # so the text encoders and the VAE stay in memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ compute_embeddings_fn = functools.partial(
+ encode_prompt,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ )
+ compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+
+ def collate_fn(examples):
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ prompts=[example["caption"] for example in examples]
+ masked_image = torch.stack([example["masked_image"] for example in examples])
+ mask = torch.stack([example["mask"] for example in examples])
+
+ return {
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ "pixel_values": pixel_values,
+ "prompts": prompts,
+ "masked_image":masked_image,
+ "mask":mask,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Compute the prompt embeddings and VAE latents for this batch on the fly.
+ batch = compute_embeddings_fn(batch, accelerator.device)
+ batch = compute_vae_encodings_fn(batch, accelerator.device, "pixel_values", "model_input")
+ model_input = batch["model_input"].to(accelerator.device)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+
+ # Encode the masked image to latents as well
+ batch = compute_vae_encodings_fn(batch, accelerator.device, "masked_image", "masked_latents")
+
+ masked_latents = batch["masked_latents"].to(accelerator.device)
+
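+ # Downsample the binary mask to the latent resolution so it can be concatenated channel-wise
+ # with the noisy latents and the masked-image latents (the 9-channel input expected by an
+ # inpainting-style UNet).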
+ mask = torch.nn.functional.interpolate(batch["mask"], size=masked_latents.shape[-2:])
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
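+ # Offset noise perturbs the per-sample mean of the noise, which helps the model learn to
+ # generate very dark or very bright images.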
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+ if args.timestep_bias_strategy == "none":
+ # Sample a random timestep for each image without bias.
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ else:
+ # Sample a random timestep for each image, potentially biased by the timestep weights.
+ # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
+ weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
+ model_input.device
+ )
+ timesteps = torch.multinomial(weights, bsz, replacement=True).long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+ noisy_model_input = torch.cat([noisy_model_input, mask, masked_latents], dim=1)
+
+ # time ids
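+ # SDXL micro-conditioning: the UNet additionally receives the original image size, the crop
+ # top-left coordinates, and the target resolution as embedded "time ids".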
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
+ ).sample
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ elif noise_scheduler.config.prediction_type == "sample":
+ # We set the target to latents here, but the model_pred will return the noise sample prediction.
+ target = model_input
+ # We will have to subtract the noise residual from the prediction to get the target sample.
+ model_pred = model_pred - noise
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective requires that we add one to SNR values before we divide by them.
+ snr = snr + 1
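+ # Min-SNR weighting: scale each sample's loss by min(SNR, gamma) / SNR.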
+ mse_loss_weights = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = unet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ # create pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ # Serialize pipeline.
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline.save_pretrained(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline = pipeline.to(accelerator.device)
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id=repo_id,
+ images=images,
+ validation_prompt=args.validation_prompt,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py
index 5013e78303e2..5a80ed1c69dd 100644
--- a/utils/check_config_docstrings.py
+++ b/utils/check_config_docstrings.py
@@ -36,7 +36,7 @@
# Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
# For example, `[bert-base-uncased](https://huggingface.co/bert-base-uncased)`
-_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
+_re_checkpoint = re.compile("\[(.+?)\]\((https://huggingface\.co/.+?)\)")
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
diff --git a/utils/check_copies.py b/utils/check_copies.py
index 2563aff10dff..df5816b4ac03 100644
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -17,7 +17,9 @@
import glob
import os
import re
-import subprocess
+
+import black
+from doc_builder.style_doc import style_docstrings_in_code
# All paths are set with the intent you should run this script from the root of the repo with the command
@@ -44,12 +46,7 @@ def find_code_in_diffusers(object_name):
if i >= len(parts):
raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.")
- with open(
- os.path.join(DIFFUSERS_PATH, f"{module}.py"),
- "r",
- encoding="utf-8",
- newline="\n",
- ) as f:
+ with open(os.path.join(DIFFUSERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Now let's find the class / func in the code!
@@ -93,29 +90,17 @@ def get_indent(code):
return ""
-def run_ruff(code):
- command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
- process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
- stdout, _ = process.communicate(input=code.encode())
- return stdout.decode()
-
-
-def stylify(code: str) -> str:
+def blackify(code):
"""
- Applies the ruff part of our `make style` command to some code. This formats the code using `ruff format`.
- As `ruff` does not provide a python api this cannot be done on the fly.
-
- Args:
- code (`str`): The code to format.
-
- Returns:
- `str`: The formatted code.
+ Applies the black part of our `make style` command to `code`.
"""
has_indent = len(get_indent(code)) > 0
if has_indent:
code = f"class Bla:\n{code}"
- formatted_code = run_ruff(code)
- return formatted_code[len("class Bla:\n") :] if has_indent else formatted_code
+ mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=119, preview=True)
+ result = black.format_str(code, mode=mode)
+ result, _ = style_docstrings_in_code(result)
+ return result[len("class Bla:\n") :] if has_indent else result
def is_copy_consistent(filename, overwrite=False):
@@ -175,9 +160,9 @@ def is_copy_consistent(filename, overwrite=False):
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
- # stylify after replacement. To be able to do that, we need the header (class or function definition)
+ # Blackify after replacement. To be able to do that, we need the header (class or function definition)
# from the previous line
- theoretical_code = stylify(lines[start_index - 1] + theoretical_code)
+ theoretical_code = blackify(lines[start_index - 1] + theoretical_code)
theoretical_code = theoretical_code[len(lines[start_index - 1]) :]
# Test for a diff and act accordingly.
@@ -212,11 +197,7 @@ def check_copies(overwrite: bool = False):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument(
- "--fix_and_overwrite",
- action="store_true",
- help="Whether to fix inconsistencies.",
- )
+ parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
args = parser.parse_args()
check_copies(args.fix_and_overwrite)
diff --git a/utils/check_inits.py b/utils/check_inits.py
index 515419908f91..6b1cdb6fcefd 100644
--- a/utils/check_inits.py
+++ b/utils/check_inits.py
@@ -36,9 +36,9 @@
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
# Catches a line with an object between quotes and a comma: "MyModel",
-_re_quote_object = re.compile(r'^\s+"([^"]+)",')
+_re_quote_object = re.compile('^\s+"([^"]+)",')
# Catches a line with objects between brackets only: ["foo", "bar"],
-_re_between_brackets = re.compile(r"^\s+\[([^\]]+)\]")
+_re_between_brackets = re.compile("^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Catches a line with try:
@@ -79,7 +79,7 @@ def parse_init(init_file):
# If we have everything on a single line, let's deal with it.
if _re_one_line_import_struct.search(line):
content = _re_one_line_import_struct.search(line).groups()[0]
- imports = re.findall(r"\[([^\]]+)\]", content)
+ imports = re.findall("\[([^\]]+)\]", content)
for imp in imports:
objects.extend([obj[1:-1] for obj in imp.split(", ")])
line_index += 1
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 5f48d01d354e..6f0417d69065 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -533,7 +533,7 @@ def find_all_documented_objects():
for doc_file in Path(PATH_TO_DOC).glob("**/*.md"):
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
content = f.read()
- raw_doc_objs = re.findall(r"\[\[autodoc\]\]\s+(\S+)\s+", content)
+ raw_doc_objs = re.findall("\[\[autodoc\]\]\s+(\S+)\s+", content)
documented_obj += [obj.split(".")[-1] for obj in raw_doc_objs]
return documented_obj
diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py
index 2de3940342d0..e1e85974aeed 100644
--- a/utils/custom_init_isort.py
+++ b/utils/custom_init_isort.py
@@ -16,7 +16,7 @@
Utility that sorts the imports in the custom inits of Diffusers. Diffusers uses init files that delay the
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
-delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
+delayed imports have two halves: one definining a dictionary `_import_structure` which maps modules to the name of the
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. `isort` or `ruff`
properly sort the second half which looks like traditionl imports, the goal of this script is to sort the first half.
diff --git a/utils/fetch_torch_cuda_pipeline_test_matrix.py b/utils/fetch_torch_cuda_pipeline_test_matrix.py
index 302898789728..41a9c1c8270d 100644
--- a/utils/fetch_torch_cuda_pipeline_test_matrix.py
+++ b/utils/fetch_torch_cuda_pipeline_test_matrix.py
@@ -34,11 +34,8 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):
if usage < usage_cutoff:
continue
- is_diffusers_pipeline = hasattr(diffusers.pipelines, diffusers_object)
- if not is_diffusers_pipeline:
- continue
-
- output.append(diffusers_object)
+ if "Pipeline" in diffusers_object:
+ output.append(diffusers_object)
return output
@@ -74,7 +71,6 @@ def fetch_pipeline_modules_to_test():
test_modules = []
for pipeline_name in pipeline_objects:
module = getattr(diffusers, pipeline_name)
-
test_module = module.__module__.split(".")[-2].strip()
test_modules.append(test_module)
diff --git a/utils/release.py b/utils/release.py
index a0800b99fbeb..758fb70caaca 100644
--- a/utils/release.py
+++ b/utils/release.py
@@ -130,7 +130,7 @@ def pre_release_work(patch=False):
def post_release_work():
- """Do all the necessary post-release steps."""
+ """Do all the necesarry post-release steps."""
# First let's get the current version
current_version = get_version()
dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
diff --git a/utils/stale.py b/utils/stale.py
index f9c0af89f5a6..12932f31c243 100644
--- a/utils/stale.py
+++ b/utils/stale.py
@@ -17,7 +17,6 @@
"""
import os
from datetime import datetime as dt
-from datetime import timezone
from github import Github
@@ -44,8 +43,8 @@ def main():
if (
last_comment is not None
and last_comment.user.login == "github-actions[bot]"
- and (dt.now(timezone.utc) - issue.updated_at).days > 7
- and (dt.now(timezone.utc) - issue.created_at).days >= 30
+ and (dt.utcnow() - issue.updated_at).days > 7
+ and (dt.utcnow() - issue.created_at).days >= 30
and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels())
):
# Closes the issue after 7 days of inactivity since the Stalebot notification.
@@ -59,8 +58,8 @@ def main():
issue.edit(state="open")
issue.remove_from_labels("stale")
elif (
- (dt.now(timezone.utc) - issue.updated_at).days > 23
- and (dt.now(timezone.utc) - issue.created_at).days >= 30
+ (dt.utcnow() - issue.updated_at).days > 23
+ and (dt.utcnow() - issue.created_at).days >= 30
and not any(label.name.lower() in LABELS_TO_EXEMPT for label in issue.get_labels())
):
# Post a Stalebot notification after 23 days of inactivity.