
Refactor cross attention and allow mechanism to tweak cross attention function #1639


Merged

merged 23 commits into from Dec 20, 2022

Changes from 4 commits
202 changes: 52 additions & 150 deletions src/diffusers/models/attention.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import Optional

@@ -25,6 +24,7 @@
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from ..utils.import_utils import is_xformers_available
from .cross_attention_processors import CrossAttentionProcMixin, CrossAttentionProc, XFormersCrossAttentionProc, SlicedAttentionProc


@dataclass
@@ -176,7 +176,7 @@ def __init__(
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, cross_attention_kwargs=None, return_dict: bool = True):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
@@ -214,7 +214,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu

# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs)

# 3. Output
if self.is_input_continuous:
@@ -448,49 +448,23 @@ def __init__(
self.norm3 = nn.LayerNorm(dim)

# if xformers is installed try to use memory_efficient_attention by default
if is_xformers_available():
try:
self.set_use_memory_efficient_attention_xformers(True)
except Exception as e:
warnings.warn(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

def forward(self, hidden_states, context=None, timestep=None):
# if is_xformers_available():
# try:
# self.set_use_memory_efficient_attention_xformers(True)
# except Exception as e:
# warnings.warn(
# "Could not enable memory efficient attention. Make sure xformers is installed"
# f" correctly and a GPU is available: {e}"
# )

def forward(self, hidden_states, context=None, timestep=None, cross_attention_kwargs=None):
# 1. Self-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)

if self.only_cross_attention:
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
hidden_states = self.attn1(norm_hidden_states, context=context, cross_attention_kwargs=cross_attention_kwargs) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states

@@ -499,7 +473,7 @@ def forward(self, hidden_states, context=None, timestep=None):
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
hidden_states = self.attn2(norm_hidden_states, context=context, cross_attention_kwargs=cross_attention_kwargs) + hidden_states

# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -544,7 +518,6 @@ def __init__(
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False

self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
@@ -554,127 +527,56 @@ def __init__(
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))

def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
self.attn_proc = CrossAttentionProc(self.heads, self.upcast_attention)
Contributor Author:
we could also add this directly in the `BasicTransformerBlock` so that the user has to go one less level deep.
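
For context, here is a minimal sketch of what a default processor such as `CrossAttentionProc` might look like. The `cross_attention_processors` module itself is not part of this diff, so everything below is inferred from how the processor is constructed (`CrossAttentionProc(self.heads, self.upcast_attention)`) and called (`attn_proc(hidden_states, to_q, to_k, to_v, context=context, **cross_attention_kwargs)`); treat it as an assumption, not the actual implementation:

```python
import torch


class CrossAttentionProcMixin:
    # Assumed base class: stores the config the processors are constructed with in this diff.
    def __init__(self, heads, upcast_attention=False):
        self.heads = heads
        self.upcast_attention = upcast_attention


class CrossAttentionProc(CrossAttentionProcMixin):
    # Plain (non-sliced, non-xformers) attention, matching the call signature used in
    # CrossAttention.forward below.
    def __call__(self, hidden_states, to_q, to_k, to_v, context=None, **cross_attention_kwargs):
        context = context if context is not None else hidden_states

        query = to_q(hidden_states)
        key = to_k(context)
        value = to_v(context)

        # fold the heads into the batch dimension: (b, n, h*d) -> (b*h, n, d)
        def heads_to_batch(t):
            b, n, hd = t.shape
            d = hd // self.heads
            return t.reshape(b, n, self.heads, d).permute(0, 2, 1, 3).reshape(b * self.heads, n, d)

        query, key, value = (heads_to_batch(t) for t in (query, key, value))

        if self.upcast_attention:
            query, key = query.float(), key.float()

        scale = query.shape[-1] ** -0.5  # assumption: same 1/sqrt(head_dim) scale as the removed _attention
        scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=scale,
        )
        probs = scores.softmax(dim=-1).to(value.dtype)  # cast back in case of upcasting
        out = torch.bmm(probs, value)

        # unfold the heads again: (b*h, n, d) -> (b, n, h*d)
        bh, n, d = out.shape
        b = bh // self.heads
        return out.reshape(b, self.heads, n, d).permute(0, 2, 1, 3).reshape(b, n, self.heads * d)
```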


def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn_proc = XFormersCrossAttentionProc(self.heads, self.upcast_attention)

def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

self._slice_size = slice_size

def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, _ = hidden_states.shape

query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)

dim = query.shape[-1]

query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)

# TODO(PVP) - mask is currently never used. Remember to re-implement when used

# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

self.attn_proc = SlicedAttentionProc(self.heads, self.upcast_attention)

def set_cross_attn_proc(self, attn_proc: CrossAttentionProcMixin):
if not isinstance(attn_proc, CrossAttentionProcMixin):
subclass = attn_proc.__bases__ if hasattr(attn_proc, "__bases__") else None
raise ValueError(f"`attn_proc` should be a subclass of {CrossAttentionProc}, but is of type {type(attn_proc)} and a subclass of {subclass}.")
self.attn_proc = attn_proc
Contributor:
Bit of a Python anti-pattern here to manually check that the base classes include the mixin, no?

Since the transformer class doesn't use any of the mixin internals, maybe just change the type signature to `Callable` and remove the manual type check?
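
For illustration, the alternative suggested above could be as small as this (a sketch of the reviewer's idea, not something in this diff):

```python
from typing import Callable


# Sketched standalone; this body would replace CrossAttention.set_cross_attn_proc.
def set_cross_attn_proc(self, attn_proc: Callable):
    # Duck typing: accept anything callable as
    # attn_proc(hidden_states, to_q, to_k, to_v, context=..., **cross_attention_kwargs)
    if not callable(attn_proc):
        raise ValueError(f"`attn_proc` must be callable, got {type(attn_proc)}.")
    self.attn_proc = attn_proc
```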


def forward(self, hidden_states, context=None, cross_attention_kwargs=None):
# attn
Contributor Author:
the whole forward function should probably become the `attn_proc` so that LoRA can be nicely supported as well.

cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
hidden_states = self.attn_proc(hidden_states, self.to_q, self.to_k, self.to_v, context=context, **cross_attention_kwargs)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
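
Following up on the comment above about LoRA: with the hook as it stands, a processor receives `to_q`/`to_k`/`to_v` and can already wrap them, but `to_out` stays outside the processor, which is the limitation that moving the whole forward into `attn_proc` would remove. A hypothetical sketch, building on the `CrossAttentionProc` sketch in the earlier comment; all names and dimensions below are assumptions for illustration:

```python
import torch.nn as nn


class LoRAQueryProc(CrossAttentionProc):
    # Illustrative only: adds a trainable low-rank update to the query projection.
    # Real LoRA would zero-init q_up, and the adapter layers would live on an nn.Module
    # so their parameters get registered; that plumbing is omitted to keep the sketch short.
    def __init__(self, heads, upcast_attention, query_dim, inner_dim, rank=4):
        super().__init__(heads, upcast_attention)
        self.q_down = nn.Linear(query_dim, rank, bias=False)
        self.q_up = nn.Linear(rank, inner_dim, bias=False)

    def __call__(self, hidden_states, to_q, to_k, to_v, context=None, lora_scale=1.0, **kwargs):
        # wrap the frozen query projection with the low-rank delta, keep k/v untouched
        lora_to_q = lambda x: to_q(x) + lora_scale * self.q_up(self.q_down(x))
        return super().__call__(hidden_states, lora_to_q, to_k, to_v, context=context)


# usage sketch on an existing CrossAttention module `attn` (hypothetical dimensions);
# lora_scale reaches __call__ via the new cross_attention_kwargs plumbing in this diff
proc = LoRAQueryProc(attn.heads, attn.upcast_attention, query_dim=320, inner_dim=320)
attn.set_cross_attn_proc(proc)
out = attn(hidden_states, context=encoder_hidden_states, cross_attention_kwargs={"lora_scale": 0.5})
```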

def _attention(self, query, key, value):
if self.upcast_attention:
query = query.float()
key = key.float()

attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)

# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)

# compute attention output
hidden_states = torch.bmm(attention_probs, value)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def _sliced_attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size

query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]

if self.upcast_attention:
query_slice = query_slice.float()
key_slice = key_slice.float()

attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attn_slice = attn_slice.softmax(dim=-1)

# cast back to the original dtype
attn_slice = attn_slice.to(value.dtype)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

hidden_states[start_idx:end_idx] = attn_slice

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def _memory_efficient_attention_xformers(self, query, key, value):
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states


class FeedForward(nn.Module):
r"""