diff --git a/examples/community/README.md b/examples/community/README.md index 47b129ce9e7e..974f77fd1011 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -1320,3 +1320,46 @@ prompt = "photorealistic new zealand hills" image = pipe(prompt, image=input_image, strength=0.75,).images[0] image.save('tensorrt_img2img_new_zealand_hills.png') ``` + +### Stable Diffusion Reference + +This pipeline uses the Reference only Control. Refer to the [sd-webui-controlnet discussion](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236). + + +```py +import torch +from diffusers import UniPCMultistepScheduler +from diffusers.utils import load_image + +input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") + +pipe = StableDiffusionReferencePipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + safety_checker=None, + torch_dtype=torch.float16 + ).to('cuda:0') + +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + +result_img = pipe(ref_image=input_image, + prompt="1girl", + num_inference_steps=20, + reference_attn=True, + reference_adain=True).images[0] +``` + +Reference Image + +![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png) + +Output Image of `reference_attn=True` and `reference_adain=False` + +![output_image](https://github.com/huggingface/diffusers/assets/24734142/813b5c6a-6d89-46ba-b7a4-2624e240eea5) + +Output Image of `reference_attn=False` and `reference_adain=True` + +![output_image](https://github.com/huggingface/diffusers/assets/24734142/ffc90339-9ef0-4c4d-a544-135c3e5644da) + +Output Image of `reference_attn=True` and `reference_adain=True` + +![output_image](https://github.com/huggingface/diffusers/assets/24734142/3c5255d6-867d-4d35-b202-8dfd30cc6827) diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py new file mode 100644 index 000000000000..5e8051cdcdb2 --- /dev/null +++ b/examples/community/stable_diffusion_reference.py @@ -0,0 +1,774 @@ +# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch + +from diffusers import StableDiffusionPipeline +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import UniPCMultistepScheduler + >>> from diffusers.utils import load_image + + >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") + + >>> pipe = StableDiffusionReferencePipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + safety_checker=None, + torch_dtype=torch.float16 + ).to('cuda:0') + + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config) + + >>> result_img = pipe(ref_image=input_image, + prompt="1girl", + num_inference_steps=20, + reference_attn=True, + reference_adain=True).images[0] + + >>> result_img.show() + ``` +""" + + +def torch_dfs(model: torch.nn.Module): + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + + +class StableDiffusionReferencePipeline(StableDiffusionPipeline): + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = (image - 0.5) / 0.5 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): + refimage = refimage.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + ref_image_latents = [ + self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + ref_image_latents = torch.cat(ref_image_latents, dim=0) + else: + ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator) + ref_image_latents = self.vae.config.scaling_factor * ref_image_latents + + # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method + if ref_image_latents.shape[0] < batch_size: + if not batch_size % ref_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1) + + ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents + + # aligning device to prevent device errors when concating it with the latent model input + ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) + return ref_image_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + ref_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_auto_machine_weight: float = 1.0, + gn_auto_machine_weight: float = 1.0, + style_fidelity: float = 0.5, + reference_attn: bool = True, + reference_adain: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + ref_image (`torch.FloatTensor`, `PIL.Image.Image`): + The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can + also be accepted as an image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + attention_auto_machine_weight (`float`): + Weight of using reference query for self attention's context. + If attention_auto_machine_weight=1.0, use reference query for all self attention's context. + gn_auto_machine_weight (`float`): + Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins. + style_fidelity (`float`): + style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important, + elif style_fidelity=0.0, prompt more important, else balanced. + reference_attn (`bool`): + Whether to use reference query for self attention's context. + reference_adain (`bool`): + Whether to use reference adain. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True." + + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, ref_image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess reference image + ref_image = self.prepare_image( + image=ref_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=prompt_embeds.dtype, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare reference latent variables + ref_image_latents = self.prepare_ref_latents( + ref_image, + batch_size * num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Modify self attention and group norm + MODE = "write" + uc_mask = ( + torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt) + .type_as(ref_image_latents) + .bool() + ) + + def hacked_basic_transformer_inner_forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + class_labels=None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if self.only_cross_attention: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + else: + if MODE == "write": + self.bank.append(norm_hidden_states.detach().clone()) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if MODE == "read": + if attention_auto_machine_weight > self.attn_weight: + attn_output_uc = self.attn1( + norm_hidden_states, + encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1), + # attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output_c = attn_output_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + attn_output_c[uc_mask] = self.attn1( + norm_hidden_states[uc_mask], + encoder_hidden_states=norm_hidden_states[uc_mask], + **cross_attention_kwargs, + ) + attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc + self.bank.clear() + else: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + def hacked_mid_forward(self, *args, **kwargs): + eps = 1e-6 + x = self.original_forward(*args, **kwargs) + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) + var_acc = sum(self.var_bank) / float(len(self.var_bank)) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + x_uc = (((x - mean) / std) * std_acc) + mean_acc + x_c = x_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + x_c[uc_mask] = x[uc_mask] + x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc + self.mean_bank = [] + self.var_bank = [] + return x + + def hack_CrossAttnDownBlock2D_forward( + self, + hidden_states, + temb=None, + encoder_hidden_states=None, + attention_mask=None, + cross_attention_kwargs=None, + ): + eps = 1e-6 + + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_DownBlock2D_forward(self, hidden_states, temb=None): + eps = 1e-6 + + output_states = () + + for i, resnet in enumerate(self.resnets): + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_CrossAttnUpBlock2D_forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + cross_attention_kwargs=None, + upsample_size=None, + attention_mask=None, + ): + eps = 1e-6 + # TODO(Patrick, William) - attention mask is not used + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + eps = 1e-6 + for i, resnet in enumerate(self.resnets): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + if reference_attn: + attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)] + attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + + for i, module in enumerate(attn_modules): + module._original_inner_forward = module.forward + module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) + module.bank = [] + module.attn_weight = float(i) / float(len(attn_modules)) + + if reference_adain: + gn_modules = [self.unet.mid_block] + self.unet.mid_block.gn_weight = 0 + + down_blocks = self.unet.down_blocks + for w, module in enumerate(down_blocks): + module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) + gn_modules.append(module) + + up_blocks = self.unet.up_blocks + for w, module in enumerate(up_blocks): + module.gn_weight = float(w) / float(len(up_blocks)) + gn_modules.append(module) + + for i, module in enumerate(gn_modules): + if getattr(module, "original_forward", None) is None: + module.original_forward = module.forward + if i == 0: + # mid_block + module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) + elif isinstance(module, CrossAttnDownBlock2D): + module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) + elif isinstance(module, DownBlock2D): + module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) + elif isinstance(module, CrossAttnUpBlock2D): + module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) + elif isinstance(module, UpBlock2D): + module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) + module.mean_bank = [] + module.var_bank = [] + module.gn_weight *= 2 + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # ref only part + noise = randn_tensor( + ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype + ) + ref_xt = self.scheduler.add_noise( + ref_image_latents, + noise, + t.reshape( + 1, + ), + ) + ref_xt = self.scheduler.scale_model_input(ref_xt, t) + + MODE = "write" + self.unet( + ref_xt, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + ) + + # predict the noise residual + MODE = "read" + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)