
T5Attention support for cross-attention #2654


Merged: 18 commits, Mar 15, 2023
57 changes: 37 additions & 20 deletions src/diffusers/models/cross_attention.py
@@ -59,6 +59,8 @@ def __init__(
cross_attention_norm: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
@@ -68,7 +70,7 @@ def __init__(
self.upcast_softmax = upcast_softmax
self.cross_attention_norm = cross_attention_norm

self.scale = dim_head**-0.5
self.scale = dim_head**-0.5 if scale_qk else 1.0

self.heads = heads
# for slice_size > 0 the attention score computation
@@ -95,14 +97,18 @@ def __init__(
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
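For illustration (outside the diff): the two new flags exist so that T5-style attention can be expressed, since T5 applies no 1/sqrt(d) rescaling of the query-key product and uses bias-free projections. A minimal sketch of what scale_qk changes, with illustrative tensor names rather than the library API:

import torch

def attention_probs(query, key, scale_qk=True):
    # Mirrors `self.scale` above: 1/sqrt(dim_head) for standard attention,
    # 1.0 when scale_qk is False (T5-style).
    scale = query.shape[-1] ** -0.5 if scale_qk else 1.0
    scores = torch.matmul(query, key.transpose(-1, -2)) * scale
    return scores.softmax(dim=-1)

# e.g. attention_probs(torch.randn(2, 8, 64), torch.randn(2, 8, 64), scale_qk=False)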

# set attention processor
# We use the AttnProcessor2_0 by default when torch2.x is used which uses
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the `scale` argument
if processor is None:
processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
if torch.torch_version.TorchVersion(torch.__version__) >= (2, 1, 0):
Contributor: Let's revert this. We don't need 2.1; 2.0 is enough, and I think the logic before was good.

Contributor Author: Right, but then scaled_dot_product_attention in 2.0 has no scale argument, which is what I would need... but yes, I can deal with that in the pipeline?

Contributor: Ah, I see. OK, I think it's fine if PyTorch 2.0 doesn't work yet for the spectrogram model. Let's maybe just advertise it with the previous PyTorch version and see if the community tries it out on PyTorch 2.0.

Contributor Author: OK, cool! Reverting... I can deal with it, or I can also check if attn.scale == 1 and not take this path, which is only needed for the spectrogram model for now?

processor = AttnProcessor2_0()
else:
processor = CrossAttnProcessor()
self.set_processor(processor)
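For illustration (outside the diff): as the discussion above settles, the fused F.scaled_dot_product_attention path exists from PyTorch 2.0, but its scale keyword only appears in 2.1, so a module that needs a non-default scale (scale_qk=False) cannot rely on it on 2.0. A hedged sketch of one way to gate the choice; the helper name is made up:

import torch
import torch.nn.functional as F

def sdpa_handles_custom_scale() -> bool:
    # The fused kernel exists from PyTorch 2.0, but its `scale` keyword
    # (needed when attn.scale != 1/sqrt(dim_head)) only appears in 2.1.
    if not hasattr(F, "scaled_dot_product_attention"):
        return False
    return torch.torch_version.TorchVersion(torch.__version__) >= (2, 1, 0)

On PyTorch 2.0 the module would then fall back to a non-fused processor whenever a non-default scale is required, which matches the attn.scale == 1 check floated in the thread.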

def set_use_memory_efficient_attention_xformers(
@@ -295,7 +301,9 @@ def __call__(
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
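For illustration (outside the diff): the shape change above matters because in cross-attention the attention mask covers the key/value positions, whose count comes from encoder_hidden_states, not from the query sequence. Illustrative shapes only:

import torch

batch, query_tokens, key_tokens, dim = 2, 16, 77, 64
hidden_states = torch.randn(batch, query_tokens, dim)        # queries
encoder_hidden_states = torch.randn(batch, key_tokens, dim)  # keys / values

# prepare_attention_mask pads or tiles the mask to `sequence_length`; taking that
# length from hidden_states (16) would no longer line up with the 77 key tokens,
# while taking it from encoder_hidden_states keeps the mask aligned with the
# (batch * heads, query_tokens, key_tokens) score matrix.
sequence_length = encoder_hidden_states.shape[1]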

@@ -362,7 +370,9 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -435,7 +445,9 @@ def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op

def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -454,7 +466,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
value = attn.head_to_batch_dim(value).contiguous()

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
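For illustration (outside the diff): xformers.ops.memory_efficient_attention likewise defaults to 1/sqrt(head_dim) when scale is None, so forwarding attn.scale keeps the T5-style scale of 1.0 on this path too. A sketch of the call, assuming a CUDA build of xformers is installed (not runnable on CPU-only setups):

import torch
import xformers.ops

q = torch.randn(2, 77, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 77, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 77, 64, device="cuda", dtype=torch.float16)

out_default = xformers.ops.memory_efficient_attention(q, k, v)              # scale = 1/sqrt(64)
out_unscaled = xformers.ops.memory_efficient_attention(q, k, v, scale=1.0)  # T5-style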
@@ -472,7 +484,9 @@ def __init__(self):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, inner_dim = hidden_states.shape
batch_size, sequence_length, inner_dim = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -497,7 +511,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No

# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=attn.scale
Contributor: Is this backwards compatible?

Contributor Author: Yes: if scale=None, the default scale is used, i.e. 1/sqrt(D), but the scale argument only works in the 2.1 nightly.

)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
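For illustration (outside the diff): a quick check of the backwards-compatibility point above, runnable only on a PyTorch build whose scaled_dot_product_attention accepts scale (2.1+); on 2.0 the keyword does not exist:

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))

out_default = F.scaled_dot_product_attention(q, k, v)                     # uses 1/sqrt(16)
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=16 ** -0.5)  # same result
assert torch.allclose(out_default, out_explicit, atol=1e-6)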
@@ -527,7 +541,9 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -542,7 +558,7 @@ def __call__(
value = attn.head_to_batch_dim(value).contiguous()

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = attn.batch_to_head_dim(hidden_states)

@@ -559,8 +575,9 @@ def __init__(self, slice_size):
self.slice_size = slice_size

def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

query = attn.to_q(hidden_states)
@@ -577,12 +594,12 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

batch_size_attention = query.shape[0]
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)

for i in range(hidden_states.shape[0] // self.slice_size):
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
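For illustration (outside the diff): the preallocated output must be sized by the query token count, which in cross-attention differs from sequence_length (the key length), while the loop walks the batch * heads dimension in chunks of slice_size. Illustrative numbers:

import torch

batch_size_attention, query_tokens, head_dim = 16, 4096, 40  # batch * heads, query length, dim // heads
slice_size = 4

# Output buffer sized by the *query* length, as in the fixed code; the key length
# (e.g. 77 encoder tokens) never enters this shape.
hidden_states = torch.zeros(batch_size_attention, query_tokens, head_dim)
for i in range(batch_size_attention // slice_size):
    start_idx, end_idx = i * slice_size, (i + 1) * slice_size
    # the slice for rows [start_idx:end_idx] has shape (slice_size, query_tokens, head_dim)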

@@ -638,12 +655,12 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

batch_size_attention = query.shape[0]
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)

for i in range(hidden_states.shape[0] // self.slice_size):
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size

3 changes: 2 additions & 1 deletion tests/models/test_models_unet_2d_condition.py
@@ -118,7 +118,8 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()

assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersCrossAttnProcessor"
), "xformers is not enabled"

@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")