From 9195ef76320ab987ad8790363226c18d3398a8e1 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 09:35:20 +0000 Subject: [PATCH 1/4] Add `dynamic_shifting` to SD3 --- .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 513f86441c3a..27be1c917a87 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -702,6 +702,7 @@ def __call__( skip_layer_guidance_scale: int = 2.8, skip_layer_guidance_stop: int = 0.2, skip_layer_guidance_start: int = 0.01, + mu: Optional[float] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -802,6 +803,7 @@ def __call__( `skip_guidance_layers` will start. The guidance will be applied to the layers specified in `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by StabiltyAI for Stable Diffusion 3.5 Medium is 0.01. + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. Examples: @@ -883,7 +885,13 @@ def __call__( pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) From 68267bddc67ad823cbee020d1d18a7d2b433284e Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 10:31:00 +0000 Subject: [PATCH 2/4] calculate_shift --- .../pipeline_stable_diffusion_3.py | 50 ++++++++++++++----- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 27be1c917a87..9ce3aa8b0040 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -68,6 +68,20 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -884,18 +898,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, - num_inference_steps, - device, - sigmas=sigmas, - mu=mu, - ) - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - self._num_timesteps = len(timesteps) - - # 5. Prepare latent variables + # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, @@ -908,6 +911,29 @@ def __call__( latents, ) + # 5. Prepare timesteps + if self.scheduler.config.use_dynamic_shifting and mu is None: + _, _, height, width = latents.shape + image_seq_len = (height // self.transformer.config.patch_size) * ( + width // self.transformer.config.patch_size + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): From 75a9febf127c3211f8adddfc8db6b14c50fb2148 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 10:50:00 +0000 Subject: [PATCH 3/4] FlowMatchHeunDiscreteScheduler doesn't support mu --- .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 9ce3aa8b0040..0a51dcbc1261 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -912,7 +912,8 @@ def __call__( ) # 5. Prepare timesteps - if self.scheduler.config.use_dynamic_shifting and mu is None: + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: _, _, height, width = latents.shape image_seq_len = (height // self.transformer.config.patch_size) * ( width // self.transformer.config.patch_size @@ -924,12 +925,15 @@ def __call__( self.scheduler.config.base_shift, self.scheduler.config.max_shift, ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, sigmas=sigmas, - mu=mu, + **scheduler_kwargs, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) From 4820d2a2cc207cb23dcdc3dc9517d44e6846766b Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 16 Dec 2024 11:36:21 +0000 Subject: [PATCH 4/4] Inpaint/img2img --- .../pipeline_stable_diffusion_3_img2img.py | 35 ++++++++++++++++++- .../pipeline_stable_diffusion_3_inpaint.py | 35 ++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 013c31c18e34..c10401324430 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -75,6 +75,20 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -748,6 +762,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, + mu: Optional[float] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -832,6 +847,7 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. Examples: @@ -913,7 +929,24 @@ def __call__( image = self.image_processor.preprocess(image, height=height, width=width) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * ( + int(width) // self.vae_scale_factor // self.transformer.config.patch_size + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index 2b6e42aa5081..ca32880d0df2 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -74,6 +74,20 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -838,6 +852,7 @@ def __call__( callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, + mu: Optional[float] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -947,6 +962,7 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. Examples: @@ -1023,7 +1039,24 @@ def __call__( pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 3. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + scheduler_kwargs = {} + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * ( + int(width) // self.vae_scale_factor // self.transformer.config.patch_size + ) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs + ) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # check that number of inference steps is not < 1 - as this doesn't make sense if num_inference_steps < 1: