From 114c74ce16e51adecc3c1a2b440a9ba3eb374f95 Mon Sep 17 00:00:00 2001
From: ppbrown
Date: Sun, 25 May 2025 11:35:52 -0700
Subject: [PATCH] Make SDXL treat the "optional" components as really optional

---
 .../pipeline_stable_diffusion_xl.py           | 300 ++++++++++++++----
 1 file changed, 235 insertions(+), 65 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 737caac51550..6f81eef7967e 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -16,6 +16,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F  # required for padding shorter sequences
 from transformers import (
     CLIPImageProcessor,
     CLIPTextModel,
@@ -371,24 +372,47 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]
 
         # Define tokenizers and text encoders
-        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
-        text_encoders = (
-            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
-        )
+        tokenizers = []
+        if self.tokenizer is not None:
+            tokenizers.append(self.tokenizer)
+        if self.tokenizer_2 is not None:
+            tokenizers.append(self.tokenizer_2)
+        if not tokenizers:
+            raise ValueError(
+                "Cannot encode prompt since no tokenizer is defined. Make sure that either `tokenizer` or `tokenizer_2` is defined."
+            )
+
+        text_encoders = []
+        if self.text_encoder is not None:
+            text_encoders.append(self.text_encoder)
+        if self.text_encoder_2 is not None:
+            text_encoders.append(self.text_encoder_2)
+        if not text_encoders:
+            raise ValueError(
+                "Cannot encode prompt since no text encoder is defined. Make sure that either `text_encoder` or `text_encoder_2` is defined."
+            )
 
         if prompt_embeds is None:
-            prompt_2 = prompt_2 or prompt
+            prompt_2 = prompt_2 or prompt  # Ensure prompt_2 is set if only one prompt is provided initially
             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
 
-            # textual inversion: process multi-vector tokens if necessary
+            prompts_list = []
+            # Determine which prompts to use based on available tokenizers/encoders
+            if self.tokenizer is not None and self.text_encoder is not None:
+                prompts_list.append(prompt)
+            if self.tokenizer_2 is not None and self.text_encoder_2 is not None:
+                # If only tokenizer_2 is available, it should use the first prompt
+                prompts_list.append(
+                    prompt_2 if (self.tokenizer is not None and self.text_encoder is not None) else prompt
+                )
+
             prompt_embeds_list = []
-            prompts = [prompt, prompt_2]
-            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+            for current_prompt, tokenizer, text_encoder in zip(prompts_list, tokenizers, text_encoders):
                 if isinstance(self, TextualInversionLoaderMixin):
-                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+                    current_prompt = self.maybe_convert_prompt(current_prompt, tokenizer)
 
                 text_inputs = tokenizer(
-                    prompt,
+                    current_prompt,
                     padding="max_length",
                     max_length=tokenizer.model_max_length,
                     truncation=True,
@@ -396,7 +420,7 @@ def encode_prompt(
                 )
 
                 text_input_ids = text_inputs.input_ids
-                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+                untruncated_ids = tokenizer(current_prompt, padding="longest", return_tensors="pt").input_ids
 
                 if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                     text_input_ids, untruncated_ids
@@ -407,84 +431,143 @@ def encode_prompt(
                         f" {tokenizer.model_max_length} tokens: {removed_text}"
                     )
 
-                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+                current_encoder_output = text_encoder(text_input_ids.to(device), output_hidden_states=True)
 
-                # We are only ALWAYS interested in the pooled output of the final text encoder
-                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
-                    pooled_prompt_embeds = prompt_embeds[0]
+                # Pooled output taken from the last text encoder
+                if text_encoder == text_encoders[-1]:  # Check if current encoder is the last one
+                    # Ensure current_encoder_output[0] is the pooled output, typically ndim == 2
+                    if (
+                        pooled_prompt_embeds is None
+                        and hasattr(current_encoder_output, "pooler_output")
+                        and current_encoder_output.pooler_output is not None
+                    ):
+                        pooled_prompt_embeds = current_encoder_output.pooler_output
+                    elif pooled_prompt_embeds is None and current_encoder_output[0].ndim == 2:  # Fallback
+                        pooled_prompt_embeds = current_encoder_output[0]
 
                 if clip_skip is None:
-                    prompt_embeds = prompt_embeds.hidden_states[-2]
+                    current_hidden_states = current_encoder_output.hidden_states[-2]
                 else:
                     # "2" because SDXL always indexes from the penultimate layer.
-                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+                    current_hidden_states = current_encoder_output.hidden_states[-(clip_skip + 2)]
 
-                prompt_embeds_list.append(prompt_embeds)
+                prompt_embeds_list.append(current_hidden_states)
 
-            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+            # ensure all hidden-state sequences have equal length before concatenation
+            max_seq_len = max(pe.shape[1] for pe in prompt_embeds_list)
+            if any(pe.shape[1] != max_seq_len for pe in prompt_embeds_list):
+                prompt_embeds_list = [
+                    F.pad(pe, (0, 0, 0, max_seq_len - pe.shape[1])) if pe.shape[1] < max_seq_len else pe
+                    for pe in prompt_embeds_list
+                ]
+
+            prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
 
         # get unconditional embeddings for classifier free guidance
         zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
         if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
             negative_prompt_embeds = torch.zeros_like(prompt_embeds)
-            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+            if pooled_prompt_embeds is not None:
+                negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+            # If pooled_prompt_embeds is None, negative_pooled_prompt_embeds is handled further below
         elif do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
-            negative_prompt_2 = negative_prompt_2 or negative_prompt
+            negative_prompt_2 = negative_prompt_2 or negative_prompt  # Ensure negative_prompt_2 is set
 
-            # normalize str to list
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
             negative_prompt_2 = (
                 batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
             )
 
-            uncond_tokens: List[str]
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
+            uncond_tokens_list = []
+            if self.tokenizer is not None and self.text_encoder is not None:
+                uncond_tokens_list.append(negative_prompt)
+            if self.tokenizer_2 is not None and self.text_encoder_2 is not None:
+                uncond_tokens_list.append(
+                    negative_prompt_2 if (self.tokenizer is not None and self.text_encoder is not None) else negative_prompt
                 )
-            else:
-                uncond_tokens = [negative_prompt, negative_prompt_2]
+
+            if prompt is not None:  # Check if prompt is not None before type comparison
+                for neg_prompt_tokens_segment in uncond_tokens_list:
+                    if batch_size != len(neg_prompt_tokens_segment):
+                        raise ValueError(
+                            f"`negative_prompt` segment has batch size {len(neg_prompt_tokens_segment)}, but `prompt`:"
+                            f" has batch size {batch_size}. Please make sure that passed `negative_prompt` segments match"
+                            " the batch size of `prompt`."
+                        )
 
             negative_prompt_embeds_list = []
-            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+            for current_negative_prompt, tokenizer, text_encoder in zip(
+                uncond_tokens_list, tokenizers, text_encoders
+            ):
                 if isinstance(self, TextualInversionLoaderMixin):
-                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+                    current_negative_prompt = self.maybe_convert_prompt(current_negative_prompt, tokenizer)
 
-                max_length = prompt_embeds.shape[1]
+                max_length = tokenizer.model_max_length
                 uncond_input = tokenizer(
-                    negative_prompt,
+                    current_negative_prompt,
                     padding="max_length",
                     max_length=max_length,
                     truncation=True,
                     return_tensors="pt",
                 )
 
-                negative_prompt_embeds = text_encoder(
+                current_neg_encoder_output = text_encoder(
                     uncond_input.input_ids.to(device),
                     output_hidden_states=True,
                 )
 
-                # We are only ALWAYS interested in the pooled output of the final text encoder
-                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
-                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]
-                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
-
-                negative_prompt_embeds_list.append(negative_prompt_embeds)
-
-            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
-
+                # Pooled output taken from the last text encoder
+                if text_encoder == text_encoders[-1]:  # Check if current encoder is the last one
+                    if (
+                        negative_pooled_prompt_embeds is None
+                        and hasattr(current_neg_encoder_output, "pooler_output")
+                        and current_neg_encoder_output.pooler_output is not None
+                    ):
+                        negative_pooled_prompt_embeds = current_neg_encoder_output.pooler_output
+                    elif negative_pooled_prompt_embeds is None and current_neg_encoder_output[0].ndim == 2:  # Fallback
+                        negative_pooled_prompt_embeds = current_neg_encoder_output[0]
+
+                current_negative_hidden_states = current_neg_encoder_output.hidden_states[-2]
+                negative_prompt_embeds_list.append(current_negative_hidden_states)
+
+            # ensure all negative hidden-state sequences have equal length before concatenation
+            max_seq_len_neg = max(ne.shape[1] for ne in negative_prompt_embeds_list)
+            if any(ne.shape[1] != max_seq_len_neg for ne in negative_prompt_embeds_list):
+                negative_prompt_embeds_list = [
+                    F.pad(ne, (0, 0, 0, max_seq_len_neg - ne.shape[1])) if ne.shape[1] < max_seq_len_neg else ne
+                    for ne in negative_prompt_embeds_list
+                ]
+
+            negative_prompt_embeds = torch.cat(negative_prompt_embeds_list, dim=-1)
+
+        # ---------------------------------------------------------------
+        # Ensure prompt embeddings match the UNet encoder hidden dim
+        # (needed if text_encoder_2 is disabled, otherwise UNet expects
+        # 2048-wide vectors but we only have 768).
+        # ---------------------------------------------------------------
+
+        expected_encoder_dim = 2048
+
+        def _match_hid_dim(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+            if t is None:
+                return None
+            if t.shape[-1] == expected_encoder_dim:
+                return t
+            if t.shape[-1] < expected_encoder_dim:
+                pad = expected_encoder_dim - t.shape[-1]
+                return F.pad(t, (0, pad))
+            return t[..., :expected_encoder_dim]
+
+        prompt_embeds = _match_hid_dim(prompt_embeds)
+        negative_prompt_embeds = _match_hid_dim(negative_prompt_embeds)
+
+        # Determine dtype for prompt_embeds
         if self.text_encoder_2 is not None:
-            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            target_dtype = self.text_encoder_2.dtype
+        elif self.text_encoder is not None:
+            target_dtype = self.text_encoder.dtype
         else:
-            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+            # This case should ideally be prevented by the checks at the start of the function
+            target_dtype = self.unet.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=target_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -495,22 +578,89 @@ def encode_prompt(
 
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            if self.text_encoder_2 is not None:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
-            else:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+            # Determine dtype for negative_prompt_embeds (should be same as prompt_embeds)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=target_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = (
+                negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            )
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-            bs_embed * num_images_per_prompt, -1
-        )
-        if do_classifier_free_guidance:
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+        # Check if pooled_prompt_embeds were generated, especially if prompts were provided.
+        if (
+            prompt_embeds is not None and pooled_prompt_embeds is None
+        ):  # prompt_embeds is the final concatenated embeddings
+            raise ValueError(
+                "Pooled prompt embeddings were not generated. Make sure the model has a pooling layer or outputs pooler_output."
+            )
+
+        if pooled_prompt_embeds is not None:
+            pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
                 bs_embed * num_images_per_prompt, -1
             )
+        if do_classifier_free_guidance:
+            # Similar check for negative_pooled_prompt_embeds
+            if (
+                negative_prompt_embeds is not None
+                and negative_pooled_prompt_embeds is None
+                and not zero_out_negative_prompt
+            ):  # negative_prompt_embeds is the final concatenated embeddings
+                raise ValueError(
+                    "Negative pooled prompt embeddings were not generated but were expected. Make sure the model has a pooling layer or outputs pooler_output."
+                )
+
+            if negative_pooled_prompt_embeds is not None:
+                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
+                    1, num_images_per_prompt
+                ).view(bs_embed * num_images_per_prompt, -1)
+            elif zero_out_negative_prompt and pooled_prompt_embeds is not None:  # If it was meant to be zeros
+                negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+            elif zero_out_negative_prompt and pooled_prompt_embeds is None:  # If zero_out and main pooled is None
+                if prompt_embeds is not None:  # We need a shape reference
+                    # Attempt to get a reference shape for text_encoder_2's projection if available
+                    ref_shape_dim = (
+                        self.text_encoder_2.config.projection_dim
+                        if self.text_encoder_2
+                        else (
+                            self.text_encoder.config.projection_dim
+                            if self.text_encoder and hasattr(self.text_encoder.config, "projection_dim")
+                            else None
+                        )
+                    )
+                    if (
+                        ref_shape_dim is None
+                        and hasattr(self.text_encoder, "text_projection")
+                        and self.text_encoder.text_projection is not None
+                    ):  # encoders that expose a text_projection weight
+                        ref_shape_dim = self.text_encoder.text_projection.shape[-1]
+                    if (
+                        ref_shape_dim is None
+                        and hasattr(self.text_encoder, "projection_dim")
+                        and self.text_encoder.projection_dim is not None
+                    ):  # encoders that expose projection_dim directly
+                        ref_shape_dim = self.text_encoder.projection_dim
+
+                    if ref_shape_dim is not None:
+                        negative_pooled_prompt_embeds = torch.zeros(
+                            bs_embed, ref_shape_dim, device=device, dtype=target_dtype
+                        )
+                        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
+                            1, num_images_per_prompt
+                        ).view(bs_embed * num_images_per_prompt, -1)
+                    else:  # Fallback if no specific projection_dim found
+                        logger.warning(
+                            "Cannot determine the projection dimension for zeroed negative_pooled_prompt_embeds. This might lead to errors."
+                        )
+                        last_encoder = text_encoders[-1]
+                        fallback_dim = last_encoder.config.hidden_size
+                        negative_pooled_prompt_embeds = torch.zeros(
+                            bs_embed, fallback_dim, device=device, dtype=target_dtype
+                        )
+                        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
+                            1, num_images_per_prompt
+                        ).view(bs_embed * num_images_per_prompt, -1)
+
         if self.text_encoder is not None:
             if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
@@ -1131,10 +1281,30 @@ def __call__(
 
         # 7. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
-        if self.text_encoder_2 is None:
-            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
-        else:
-            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        # ----------------------------------------------------------------
+        # Ensure pooled embeddings match UNet-expected projection dimension
+        # to avoid mismatched 2816 vs 2304 vectors.
+        # ----------------------------------------------------------------
+        required_proj_dim = (
+            self.unet.add_embedding.linear_1.in_features
+            - self.unet.config.addition_time_embed_dim * 6  # 6 = len(original_size + crops_coords_top_left + target_size)
+        )
+
+        def _match_dim(t):
+            if t is None:
+                return None
+            if t.shape[-1] == required_proj_dim:
+                return t
+            if t.shape[-1] < required_proj_dim:
+                pad = required_proj_dim - t.shape[-1]
+                return F.pad(t, (0, pad))
+            return t[..., :required_proj_dim]
+
+        pooled_prompt_embeds = _match_dim(pooled_prompt_embeds)
+        negative_pooled_prompt_embeds = _match_dim(negative_pooled_prompt_embeds)
+        add_text_embeds = pooled_prompt_embeds
+        text_encoder_projection_dim = required_proj_dim
 
         add_time_ids = self._get_add_time_ids(
             original_size,
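For reference, a minimal usage sketch of what this change is meant to enable: loading the SDXL pipeline without the secondary text encoder and tokenizer and still generating an image. The model id, the explicit `text_encoder_2=None` / `tokenizer_2=None` overrides, and the prompt below are illustrative assumptions, not part of the patch; with the sequence padding added in `encode_prompt` and the pooled-dimension matching added in `__call__`, such a stripped-down pipeline is expected to run instead of failing on mismatched embedding widths.

# Illustrative sketch only (assumes this patch is applied); component overrides
# and model id are assumptions, not part of the diff above.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    text_encoder_2=None,  # explicitly drop the secondary (OpenCLIP bigG) encoder
    tokenizer_2=None,
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a watercolor fox in a forest", num_inference_steps=20).images[0]
image.save("fox.png")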