T5Attention support for cross-attention #2654

Merged: 18 commits, Mar 15, 2023
56 changes: 37 additions & 19 deletions src/diffusers/models/cross_attention.py
@@ -59,6 +59,8 @@ def __init__(
cross_attention_norm: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
processor: Optional["AttnProcessor"] = None,
):
super().__init__()
@@ -68,7 +70,7 @@ def __init__(
self.upcast_softmax = upcast_softmax
self.cross_attention_norm = cross_attention_norm

self.scale = dim_head**-0.5
self.scale = dim_head**-0.5 if scale_qk else 1.0

self.heads = heads
# for slice_size > 0 the attention score computation
@@ -95,14 +97,17 @@ def __init__(
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))

# set attention processor
# We use the AttnProcessor2_0 by default when torch2.x is used which uses
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
if processor is None:
processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else CrossAttnProcessor()
)
self.set_processor(processor)

def set_use_memory_efficient_attention_xformers(
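Note on the two new constructor flags: they are what enables T5-style attention, which applies no 1/sqrt(d) scaling to the query-key product and has no bias on its output projection. A minimal sketch of how a caller might configure this, assuming the module path shown in this diff; the size arguments are hypothetical and not part of the PR:

# Illustration (not from the PR).
from diffusers.models.cross_attention import CrossAttention

attn = CrossAttention(
    query_dim=512,
    cross_attention_dim=512,
    heads=8,
    dim_head=64,
    scale_qk=False,  # T5 does not scale the QK dot product, so self.scale becomes 1.0
    out_bias=False,  # T5's output projection has no bias term
)

Per the processor-selection change above, `scale_qk=False` also means `AttnProcessor2_0` is skipped even on PyTorch 2.x, because `F.scaled_dot_product_attention` in 2.0 cannot take a custom scale.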
@@ -295,7 +300,9 @@ def __call__(
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
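This shape fix matters because, in cross-attention, the attention mask follows the key/value sequence (the encoder states), not the query sequence. A self-contained sketch with hypothetical sizes:

# Illustration (not from the PR); sizes are hypothetical.
import torch

batch, query_tokens, key_tokens = 2, 4096, 77
hidden_states = torch.randn(batch, query_tokens, 320)        # latent queries
encoder_hidden_states = torch.randn(batch, key_tokens, 768)  # text keys/values

# Corrected logic: take sequence_length from the key/value source when present,
# so the mask is padded/broadcast along the key axis (length 77, not 4096).
batch_size, sequence_length, _ = (
    hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
assert (batch_size, sequence_length) == (2, 77)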

@@ -362,7 +369,9 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -435,7 +444,9 @@ def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op

def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

@@ -454,7 +465,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
value = attn.head_to_batch_dim(value).contiguous()

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
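Forwarding `attn.scale` here keeps the xformers path consistent with `scale_qk=False`: without it, `memory_efficient_attention` falls back to its default 1/sqrt(head_dim) scaling. A hedged sketch of the call, assuming xformers is installed and a CUDA device is available; tensor sizes are hypothetical:

# Illustration (not from the PR). Batched-head layout: (batch * heads, tokens, head_dim).
import torch
import xformers.ops

q = torch.randn(16, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(16, 77, 64, device="cuda", dtype=torch.float16)
v = torch.randn(16, 77, 64, device="cuda", dtype=torch.float16)

# With scale_qk=False the module's attn.scale is 1.0; passing it explicitly
# overrides xformers' default 1/sqrt(head_dim) scaling.
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, scale=1.0)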
@@ -472,7 +483,10 @@ def __init__(self):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, inner_dim = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
inner_dim = hidden_states.shape[-1]

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -496,6 +510,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
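The TODO exists because `F.scaled_dot_product_attention` in PyTorch 2.0 has no `scale` argument and always uses 1/sqrt(head_dim), which is why a non-default `attn.scale` (i.e. `scale_qk=False`) currently falls back to `CrossAttnProcessor`. A sketch of the call once the minimum version is PyTorch 2.1; it will not run on 2.0, and the tensor sizes are hypothetical:

# Illustration (not from the PR): (batch, heads, tokens, head_dim) layout.
import torch
import torch.nn.functional as F

query = torch.randn(2, 8, 4096, 64)
key = torch.randn(2, 8, 77, 64)
value = torch.randn(2, 8, 77, 64)

# PyTorch 2.1+ accepts scale=; passing attn.scale (1.0 for T5-style attention)
# would remove the need for the scale_qk check when picking the processor.
out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=1.0
)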
@@ -527,7 +542,9 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
@@ -542,7 +559,7 @@ def __call__(
value = attn.head_to_batch_dim(value).contiguous()

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = attn.batch_to_head_dim(hidden_states)

Expand All @@ -559,8 +576,9 @@ def __init__(self, slice_size):
self.slice_size = slice_size

def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

query = attn.to_q(hidden_states)
@@ -577,12 +595,12 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

batch_size_attention = query.shape[0]
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)

for i in range(hidden_states.shape[0] // self.slice_size):
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
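The buffer fix mirrors the sequence-length fix above: in cross-attention the query length and key/value length differ, and each output slice has the query length, so the preallocated tensor must be sized by `query_tokens`. A self-contained sketch of the sliced loop with hypothetical sizes:

# Illustration (not from the PR); sizes are hypothetical.
import torch

heads, slice_size = 8, 2
query_tokens, key_tokens, head_dim = 4096, 77, 64

query = torch.randn(heads, query_tokens, head_dim)
key = torch.randn(heads, key_tokens, head_dim)
value = torch.randn(heads, key_tokens, head_dim)

# Each slice produces (slice_size, query_tokens, head_dim); sizing the buffer by
# the key/value length (77) would fail on assignment for cross-attention.
hidden_states = torch.zeros(heads, query_tokens, head_dim)
for i in range(heads // slice_size):
    start, end = i * slice_size, (i + 1) * slice_size
    probs = (query[start:end] @ key[start:end].transpose(-1, -2) * head_dim**-0.5).softmax(dim=-1)
    hidden_states[start:end] = probs @ value[start:end]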

@@ -638,12 +656,12 @@ def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)

batch_size_attention = query.shape[0]
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)

for i in range(hidden_states.shape[0] // self.slice_size):
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size

3 changes: 2 additions & 1 deletion tests/models/test_models_unet_2d_condition.py
@@ -118,7 +118,8 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()

assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersCrossAttnProcessor"
), "xformers is not enabled"

@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")