Guiders support for Wan #11211

Open · wants to merge 5 commits into base: feature/guiders

Changes from 1 commit
10 changes: 9 additions & 1 deletion src/diffusers/hooks/_helpers.py
@@ -30,7 +30,7 @@
 )
 from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
 from ..models.transformers.transformer_mochi import MochiTransformerBlock
-from ..models.transformers.transformer_wan import WanTransformerBlock
+from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock


 @dataclass
@@ -186,6 +186,14 @@ def _register_guidance_metadata():
         ),
     )

+    # Wan
+    GuidanceMetadataRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=GuidanceMetadata(
+            perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0,
+        ),
+    )


 # fmt: off
 def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
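For context on what this registration buys us: a guider can look up the PAG counterpart of a module's current attention processor and swap it in before the perturbed forward pass. A minimal sketch, assuming the branch's registry exposes a `get` accessor (the exact lookup API on feature/guiders may differ):

# Sketch only: the registry accessor and guider wiring are assumed from the
# feature/guiders branch, not a stable diffusers API.
from diffusers.hooks._helpers import GuidanceMetadataRegistry
from diffusers.models.transformers.transformer_wan import WanAttnProcessor2_0


def enable_pag_processors(transformer):
    metadata = GuidanceMetadataRegistry.get(WanAttnProcessor2_0)  # accessor name assumed
    pag_cls = metadata.perturbed_attention_guidance_processor_cls
    for module in transformer.modules():
        # Swap the default Wan processor for its PAG variant on each attention layer.
        if isinstance(getattr(module, "processor", None), WanAttnProcessor2_0):
            module.set_processor(pag_cls())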
82 changes: 82 additions & 0 deletions src/diffusers/models/transformers/transformer_wan.py
@@ -467,3 +467,85 @@ def forward(
            return (output,)

        return Transformer2DModelOutput(sample=output)


### ===== Custom attention processors for guidance methods ===== ###


class WanPAGAttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("WanPAGAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        is_encoder_hidden_states_provided = encoder_hidden_states is not None
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # For I2V, the first 257 tokens are the image-conditioning embeddings; the rest are text.
            encoder_hidden_states_img = encoder_hidden_states[:, :257]
            encoder_hidden_states = encoder_hidden_states[:, 257:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                # Interpret channel pairs as complex numbers and rotate them by the
                # precomputed frequencies (standard RoPE).
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        if is_encoder_hidden_states_provided:
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        else:
            # Perturbed attention is applied only to self-attention: replacing the
            # attention map with the identity makes the output the value tensor.
            hidden_states = value

        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
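The `hidden_states = value` branch above is the core PAG trick: in the perturbed self-attention pass, the softmax attention map is replaced by the identity matrix, and identity-weighted attention returns `value` unchanged, so no SDPA call is needed. A quick numeric check of that equivalence:

import torch

# (batch, heads, seq_len, head_dim)
value = torch.randn(1, 2, 4, 8)

# Identity attention map: every query attends only to its own position.
identity = torch.eye(4).expand(1, 2, 4, 4)
output = identity @ value

assert torch.allclose(output, value)  # identity attention is a pass-through of `value`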
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -617,7 +617,7 @@ def __call__(
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
-        prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
+        prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[c] for c in conds]

        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
            for i, t in enumerate(timesteps):
55 changes: 32 additions & 23 deletions src/diffusers/pipelines/wan/pipeline_wan.py
@@ -21,6 +21,7 @@
 from transformers import AutoTokenizer, UMT5EncoderModel

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -380,6 +381,7 @@ def __call__(
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
+        guidance: Optional[GuidanceMixin] = None,
    ):
        r"""
        The call function to the pipeline for generation.
@@ -444,6 +446,10 @@ def __call__(
                indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
        """

+        _raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
+        if guidance is None:
+            guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
+
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -519,37 +525,38 @@ def __call__(
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

+        conds = [prompt_embeds, negative_prompt_embeds]
+        prompt_embeds, negative_prompt_embeds = [[c] for c in conds]
+
        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
            for i, t in enumerate(timesteps):
+                self._current_timestep = t
                if self.interrupt:
                    continue

-                self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
-
-                cc.mark_state("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    cc.mark_state("uncond")
-                    noise_uncond = self.transformer(
-                        hidden_states=latent_model_input,
+                guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+                guidance.prepare_models(self.transformer)
+                latents, prompt_embeds = guidance.prepare_inputs(
+                    latents, (prompt_embeds[0], negative_prompt_embeds[0])
+                )
+
+                for batch_index, (latent, condition) in enumerate(zip(latents, prompt_embeds)):
+                    cc.mark_state(f"batch_{batch_index}")
+                    latent = latent.to(transformer_dtype)
+                    timestep = t.expand(latent.shape[0])
+                    noise_pred = self.transformer(
+                        hidden_states=latent,
                        timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=condition,
                        attention_kwargs=attention_kwargs,
                        return_dict=False,
                    )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    guidance.prepare_outputs(noise_pred)
Review comment from @yiyixuxu (Collaborator), Apr 8, 2025:

    First, let's test very thoroughly for any performance difference from this change (only SDXL is needed for now; vary num_images_per_prompt, machine type, etc.).

    Second, code-wise I think it's less confusing with something like this, i.e. explicitly pass the model as input (otherwise it's unclear that a model call happens there), and a function should always return an output if it modifies its input:

        noise_pred = guider.prepare_cond(self.transformer, ...)
        outputs = guider.prepare_guider_output(self.transformer, ...)

                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                outputs = guidance.outputs
+                noise_pred = guidance(**outputs)
+                latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
+                guidance.cleanup_models(self.transformer)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
@@ -558,8 +565,10 @@ def __call__(
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
+                    negative_prompt_embeds = [
+                        callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
+                    ]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
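Note that the explicit CFG combination (`noise_uncond + guidance_scale * (noise_pred - noise_uncond)`) no longer appears in the loop; it now lives in the guider's `__call__`. A minimal stand-in showing the math the loop delegates (simplified; the branch's actual signature and output bookkeeping may differ):

import torch


class MinimalCFG:
    """Simplified stand-in for ClassifierFreeGuidance's combination step."""

    def __init__(self, guidance_scale: float):
        self.guidance_scale = guidance_scale

    def __call__(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor) -> torch.Tensor:
        # Classic CFG: move the prediction from the unconditional branch
        # toward (and past) the conditional branch by guidance_scale.
        return pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)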