From 74e34e5f699b2d199e57e2bc297b9fea66a86ecb Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sat, 5 Apr 2025 00:09:09 +0200
Subject: [PATCH 1/4] guiders support for wan

---
 src/diffusers/hooks/_helpers.py               | 10 ++-
 .../models/transformers/transformer_wan.py    | 82 +++++++++++++++++++
 .../pipelines/cogview4/pipeline_cogview4.py   |  2 +-
 src/diffusers/pipelines/wan/pipeline_wan.py   | 55 +++++++------
 4 files changed, 124 insertions(+), 25 deletions(-)

diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
index c87468001e1f..070234e0946c 100644
--- a/src/diffusers/hooks/_helpers.py
+++ b/src/diffusers/hooks/_helpers.py
@@ -30,7 +30,7 @@
 )
 from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
 from ..models.transformers.transformer_mochi import MochiTransformerBlock
-from ..models.transformers.transformer_wan import WanTransformerBlock
+from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock
 
 
 @dataclass
@@ -186,6 +186,14 @@ def _register_guidance_metadata():
         ),
     )
 
+    # Wan
+    GuidanceMetadataRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=GuidanceMetadata(
+            perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0,
+        ),
+    )
+
 
 # fmt: off
 def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index aa03e97093aa..703e2f82c24e 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -467,3 +467,85 @@ def forward(
             return (output,)
 
         return Transformer2DModelOutput(sample=output)
+
+
+### ===== Custom attention processors for guidance methods ===== ###
+
+
+class WanPAGAttnProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        is_encoder_hidden_states_provided = encoder_hidden_states is not None
+        encoder_hidden_states_img = None
+        if attn.add_k_proj is not None:
+            encoder_hidden_states_img = encoder_hidden_states[:, :257]
+            encoder_hidden_states = encoder_hidden_states[:, 257:]
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        if rotary_emb is not None:
+
+            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+                return x_out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, rotary_emb)
+            key = apply_rotary_emb(key, rotary_emb)
+
+        # I2V task
+        hidden_states_img = None
+        if encoder_hidden_states_img is not None:
+            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img = attn.norm_added_k(key_img)
+            value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            hidden_states_img = F.scaled_dot_product_attention(
+                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.type_as(query)
+
+        if is_encoder_hidden_states_provided:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+        else:
+            # Perturbed attention applied only when self-attention
+            hidden_states = value
+
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        if hidden_states_img is not None:
+            hidden_states = hidden_states + hidden_states_img
+
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 5d7e09c389ae..e7e0b50f0e7a 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -617,7 +617,7 @@ def __call__(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
         conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
-        prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
+        prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[c] for c in conds]
 
         with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index 54164437e7f2..5c1b42d2f9e6 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -21,6 +21,7 @@
 from transformers import AutoTokenizer, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -380,6 +381,7 @@ def __call__(
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        guidance: Optional[GuidanceMixin] = None,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -444,6 +446,10 @@ def __call__(
                 indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
         """
 
+        _raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
+        if guidance is None:
+            guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
+
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
@@ -519,37 +525,38 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        conds = [prompt_embeds, negative_prompt_embeds]
+        prompt_embeds, negative_prompt_embeds = [[c] for c in conds]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
+                self._current_timestep = t
                 if self.interrupt:
                     continue
 
-                self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
-
-                cc.mark_state("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    cc.mark_state("uncond")
-                    noise_uncond = self.transformer(
-                        hidden_states=latent_model_input,
+                guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+                guidance.prepare_models(self.transformer)
+                latents, prompt_embeds = guidance.prepare_inputs(
+                    latents, (prompt_embeds[0], negative_prompt_embeds[0])
+                )
+
+                for batch_index, (latent, condition) in enumerate(zip(latents, prompt_embeds)):
+                    cc.mark_state(f"batch_{batch_index}")
+                    latent = latent.to(transformer_dtype)
+                    timestep = t.expand(latent.shape[0])
+                    noise_pred = self.transformer(
+                        hidden_states=latent,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=condition,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    guidance.prepare_outputs(noise_pred)
 
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                outputs = guidance.outputs
+                noise_pred = guidance(**outputs)
+                latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
+                guidance.cleanup_models(self.transformer)
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
@@ -558,8 +565,10 @@ def __call__(
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
+                    negative_prompt_embeds = [
+                        callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
+                    ]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

From 1411b33563899125bb4600874d49dbb99e18eb9e Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sat, 5 Apr 2025 03:41:25 +0530
Subject: [PATCH 2/4] Update src/diffusers/models/transformers/transformer_wan.py

---
 src/diffusers/models/transformers/transformer_wan.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 703e2f82c24e..cf9decfc9ed8 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -537,7 +537,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                 query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
             )
         else:
-            # Perturbed attention applied only when self-attention
+            # Perturbed attention applied only to self-attention path
             hidden_states = value
 
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)

From 98fdabde9e7ade26b65560a602131698f8ba525d Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sat, 5 Apr 2025 10:59:23 +0200
Subject: [PATCH 3/4] update

---
 src/diffusers/hooks/_helpers.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
index f3373af3f134..9dabc7b286b5 100644
--- a/src/diffusers/hooks/_helpers.py
+++ b/src/diffusers/hooks/_helpers.py
@@ -30,7 +30,7 @@
 )
 from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
 from ..models.transformers.transformer_mochi import MochiTransformerBlock
-from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock
+from ..models.transformers.transformer_wan import WanTransformerBlock
 
 
 @dataclass
@@ -229,14 +229,6 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
 
 
 _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
-    # Wan
-    GuidanceMetadataRegistry.register(
-        model_class=WanAttnProcessor2_0,
-        metadata=GuidanceMetadata(
-            perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0,
-        ),
-    )
-
 
 def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
     hidden_states = kwargs.get("hidden_states", None)

From 0147a6eb27e4d3b19bc2c99372632684841ccc1b Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sun, 6 Apr 2025 04:41:28 +0200
Subject: [PATCH 4/4] spatiotemporal guidance: additional wan registrations for
 attention and attention score skipping

---
 src/diffusers/hooks/_helpers.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
index 9dabc7b286b5..d885b7326e05 100644
--- a/src/diffusers/hooks/_helpers.py
+++ b/src/diffusers/hooks/_helpers.py
@@ -30,7 +30,7 @@
 )
 from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
 from ..models.transformers.transformer_mochi import MochiTransformerBlock
-from ..models.transformers.transformer_wan import WanTransformerBlock
+from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock
 
 
 @dataclass
@@ -101,6 +101,14 @@ def _register_attention_processors_metadata():
         ),
     )
 
+    # Wan
+    AttentionProcessorRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor,
+        ),
+    )
+
 
 def _register_guidance_metadata():
     # CogView4
@@ -111,6 +119,14 @@ def _register_guidance_metadata():
         ),
     )
 
+    # Wan
+    GuidanceMetadataRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=GuidanceMetadata(
+            perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0,
+        ),
+    )
+
 
 def _register_transformer_blocks_metadata():
     # CogVideoX
@@ -217,6 +233,13 @@ def _register_transformer_blocks_metadata():
 
 
 # fmt: off
+def _skip_attention___ret___hidden_states(self, *args, **kwargs):
+    hidden_states = kwargs.get("hidden_states", None)
+    if hidden_states is None and len(args) > 0:
+        hidden_states = args[0]
+    return hidden_states
+
+
 def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
     hidden_states = kwargs.get("hidden_states", None)
     encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
@@ -228,6 +251,7 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
 
 
 _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
+_skip_proc_output_fn_Attention_WanAttnProcessor = _skip_attention___ret___hidden_states
 
 
 def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
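
Usage sketch: PATCH 1/4 adds a guidance argument to WanPipeline.__call__ and falls back to
ClassifierFreeGuidance(guidance_scale=guidance_scale) when it is not passed, so on this branch a
caller can hand the pipeline an explicit guider object instead of the deprecation-warned float
guidance_scale. The snippet below shows roughly how that would look; the checkpoint id and the
generation settings are illustrative assumptions, not values fixed by these patches, and
perturbed-attention / spatiotemporal guiders are omitted because their constructors are not shown
in this diff.

    import torch
    from diffusers import WanPipeline
    from diffusers.guiders import ClassifierFreeGuidance

    # Placeholder checkpoint id; any Wan text-to-video checkpoint in the diffusers
    # format is assumed to work the same way.
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

    # Explicit guider object; the pipeline builds this same object internally when
    # only the legacy guidance_scale float is supplied.
    guidance = ClassifierFreeGuidance(guidance_scale=5.0)

    output = pipe(
        prompt="A cat walking on a beach at sunset",
        negative_prompt="blurry, low quality, distorted",
        num_inference_steps=30,
        guidance=guidance,
    )
    video = output.frames[0]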