Skip to content

Attention mask for Flux & SD3 #10044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ def __call__(
) -> torch.FloatTensor:
residual = hidden_states

batch_size = hidden_states.shape[0]
batch_size, sequence_length, _ = hidden_states.shape

# `sample` projections.
query = attn.to_q(hidden_states)
Expand Down Expand Up @@ -1129,11 +1129,27 @@ def __call__(
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

query_pad_size = query.size(2)
key_pad_size = key.size(2)

query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
if attention_mask is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well I think what we want is not to have a specific implementation to apply attention_mask for flux, is just to allow it to pass down all the way from pipeline, to transformer and the to attention processor so user can experiment with a custom attention mask

cc @christopher5106 is what I described here something you have in mind?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and if you do have an specific implementation that we want to add to diffusers, maybe you can run some experiments to help us decide if it's meaningful

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's the encoder attention mask though just following convention of pixart and other DiT that rely on attention masking. masking the attention arbitrarily doesn't unlock new use cases, does it? if so, providing examples of those would be nice.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rootonchair yes feel free to modify pipeline/model to test, and provide the experiments results to us:)

how should I test this feature? modify the original flux pipeline?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @christopher5106 - would you be able to provide a use case? since it was the original ask

masking the attention arbitrarily doesn't unlock new use cases, does it? if so, providing examples of those would be nice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rootonchair yes feel free to modify pipeline/model to test, and provide the experiments results to us:)

how should I test this feature? modify the original flux pipeline?

Sure, perhaps the simplest test would be passing a padded prompt

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking the softmax scores for padded positions.

padding_shape = (attention_mask.shape[0], query_pad_size)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
query_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(2) # N, Iq + Tq, 1

padding_shape = (attention_mask.shape[0], key_pad_size)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
key_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(1) # N, 1, Ik + Tk

attention_mask = torch.bmm(query_attention_mask, key_attention_mask)

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

Expand Down Expand Up @@ -1896,18 +1912,34 @@ def __call__(
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

query_pad_size = query.size(2)
key_pad_size = key.size(2)

# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

if attention_mask is not None:
padding_shape = (attention_mask.shape[0], query_pad_size)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
query_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(2) # N, Iq + Tq, 1

padding_shape = (attention_mask.shape[0], key_pad_size)
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
key_attention_mask = torch.cat([padding, attention_mask], dim=2).unsqueeze(1) # N, 1, Ik + Tk

attention_mask = torch.bmm(query_attention_mask, key_attention_mask)

if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb

query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False, attention_mask=attention_mask
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

Expand Down