Refactor cross attention and allow mechanism to tweak cross attention function #1639
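The core of the mechanism, sketched as standalone code. Only `torch` is assumed; the `__call__` signature mirrors how the refactored `CrossAttention.forward` in the diff below invokes its processor (`attn_proc(hidden_states, to_q, to_k, to_v, context=context, **cross_attention_kwargs)`), while the class itself and the `scale` kwarg are illustrative. Note that, as the diff stands, `set_cross_attn_proc` additionally requires the processor to subclass `CrossAttentionProcMixin` from `cross_attention_processors`; that import is omitted here to keep the sketch self-contained.

```python
import torch


class MyCrossAttentionProc:
    """Illustrative processor: receives the module's projection layers and
    returns the attention output; the module still applies its to_out layers."""

    def __init__(self, heads: int):
        self.heads = heads

    def __call__(self, hidden_states, to_q, to_k, to_v, context=None, scale=None):
        context = context if context is not None else hidden_states
        query, key, value = to_q(hidden_states), to_k(context), to_v(context)

        head_dim = query.shape[-1] // self.heads
        scale = scale if scale is not None else head_dim**-0.5

        def split(t):  # (batch, seq, heads * head_dim) -> (batch * heads, seq, head_dim)
            b, s, _ = t.shape
            return t.reshape(b, s, self.heads, head_dim).permute(0, 2, 1, 3).reshape(b * self.heads, s, head_dim)

        def merge(t):  # (batch * heads, seq, head_dim) -> (batch, seq, heads * head_dim)
            bh, s, d = t.shape
            return t.reshape(bh // self.heads, self.heads, s, d).permute(0, 2, 1, 3).reshape(bh // self.heads, s, self.heads * d)

        attn = torch.bmm(split(query), split(key).transpose(-1, -2)).mul(scale).softmax(dim=-1)
        return merge(torch.bmm(attn, split(value)))


# Hypothetical usage against the refactored CrossAttention module (per the diff,
# a real processor would also subclass CrossAttentionProcMixin to pass the
# isinstance check in set_cross_attn_proc):
#   attn_module.set_cross_attn_proc(MyCrossAttentionProc(heads=attn_module.heads))
#   out = attn_module(hidden_states, context=encoder_hidden_states,
#                     cross_attention_kwargs={"scale": 0.5})
```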
Changes from 4 commits
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from dataclasses import dataclass
from typing import Optional

@@ -25,6 +24,7 @@
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from ..utils.import_utils import is_xformers_available
from .cross_attention_processors import CrossAttentionProcMixin, CrossAttentionProc, XFormersCrossAttentionProc, SlicedAttentionProc

@dataclass

@@ -176,7 +176,7 @@ def __init__(
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, cross_attention_kwargs=None, return_dict: bool = True):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.

@@ -214,7 +214,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu

# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs)

# 3. Output
if self.is_input_continuous:

@@ -448,49 +448,23 @@ def __init__(
self.norm3 = nn.LayerNorm(dim)

# if xformers is installed try to use memory_efficient_attention by default
if is_xformers_available():
try:
self.set_use_memory_efficient_attention_xformers(True)
except Exception as e:
warnings.warn(
"Could not enable memory efficient attention. Make sure xformers is installed"
f" correctly and a GPU is available: {e}"
)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

def forward(self, hidden_states, context=None, timestep=None):
# if is_xformers_available():
# try:
# self.set_use_memory_efficient_attention_xformers(True)
# except Exception as e:
# warnings.warn(
# "Could not enable memory efficient attention. Make sure xformers is installed"
# f" correctly and a GPU is available: {e}"
# )

def forward(self, hidden_states, context=None, timestep=None, cross_attention_kwargs=None):
# 1. Self-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)

if self.only_cross_attention:
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
hidden_states = self.attn1(norm_hidden_states, context=context, cross_attention_kwargs=cross_attention_kwargs) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states

@@ -499,7 +473,7 @@ def forward(self, hidden_states, context=None, timestep=None):
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
hidden_states = self.attn2(norm_hidden_states, context=context, cross_attention_kwargs=cross_attention_kwargs) + hidden_states

# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

@@ -544,7 +518,6 @@ def __init__(
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False

self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

@@ -554,127 +527,56 @@ def __init__(
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))

def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
self.attn_proc = CrossAttentionProc(self.heads, self.upcast_attention)

def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn_fn = XFormersCrossAttentionProc(self.heads, self.upcast_attention)

def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

self._slice_size = slice_size

def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, _ = hidden_states.shape

query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)

dim = query.shape[-1]

query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)

# TODO(PVP) - mask is currently never used. Remember to re-implement when used

# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)

self.attn_proc = SlicedAttentionProc(self.heads, self.upcast_attention)

def set_cross_attn_proc(self, attn_proc: CrossAttentionProcMixin):
if not isinstance(attn_proc, CrossAttentionProcMixin):
subclass = attn_proc.__bases__ if hasattr(attn_proc, "__bases__") else None
raise ValueError(f"`attn_proc` should be a subclass of {CrossAttentionProc}, but is of type {type(attn_proc)} and a subclass of {subclass}.")
self.attn_proc = attn_proc
Review comment: Bit of a Python anti-pattern here to manually check that the base classes include the mixin, no? Since the transformer class doesn't use any of the mixin internals, maybe just change the type signature to Callable and remove the manual type check?
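A sketch of what the suggested relaxation might look like inside `CrossAttention` (hypothetical, not part of the diff): annotate the argument as `Callable` and drop the manual base-class check, letting a wrong signature fail naturally at call time.

```python
from typing import Callable


def set_cross_attn_proc(self, attn_proc: Callable):
    # Duck-typed: anything callable with the processor signature is accepted;
    # no isinstance check against CrossAttentionProcMixin is needed.
    self.attn_proc = attn_proc
```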

def forward(self, hidden_states, context=None, cross_attention_kwargs=None):
# attn
Review comment: the whole forward function should probably become the … (see the sketch after the new forward below).
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
hidden_states = self.attn_proc(hidden_states, self.to_q, self.to_k, self.to_v, context=context, **cross_attention_kwargs)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
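One possible reading of the inline comment above, as a standalone single-head sketch (all names here are hypothetical, not part of the diff): `forward` shrinks to a pure delegation and the processor owns the whole pipeline, including the output projection and dropout.

```python
import torch
import torch.nn as nn


def default_attn_proc(attn: "TinyCrossAttention", hidden_states, context=None):
    # Hypothetical default processor: it receives the attention module itself and
    # runs the full pipeline (projections, attention math, to_out layers).
    context = context if context is not None else hidden_states
    query, key, value = attn.to_q(hidden_states), attn.to_k(context), attn.to_v(context)
    scores = torch.bmm(query, key.transpose(-1, -2)) * attn.scale
    out = torch.bmm(scores.softmax(dim=-1), value)
    return attn.to_out[1](attn.to_out[0](out))


class TinyCrossAttention(nn.Module):
    def __init__(self, dim: int, attn_proc=default_attn_proc):
        super().__init__()
        self.scale = dim**-0.5
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.ModuleList([nn.Linear(dim, dim), nn.Dropout(0.0)])
        self.attn_proc = attn_proc

    def forward(self, hidden_states, context=None, cross_attention_kwargs=None):
        cross_attention_kwargs = cross_attention_kwargs or {}
        # forward() is now just a thin delegation to the processor.
        return self.attn_proc(self, hidden_states, context=context, **cross_attention_kwargs)
```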
def _attention(self, query, key, value):
if self.upcast_attention:
query = query.float()
key = key.float()

attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)

# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)

# compute attention output
hidden_states = torch.bmm(attention_probs, value)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def _sliced_attention(self, query, key, value, sequence_length, dim):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size

query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]

if self.upcast_attention:
query_slice = query_slice.float()
key_slice = key_slice.float()

attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attn_slice = attn_slice.softmax(dim=-1)

# cast back to the original dtype
attn_slice = attn_slice.to(value.dtype)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

hidden_states[start_idx:end_idx] = attn_slice

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def _memory_efficient_attention_xformers(self, query, key, value):
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states


class FeedForward(nn.Module):
r"""
Review comment: we could also add this directly in the BasicTransformerBlock so that the user has to go one less level deep.
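A sketch of that idea as a standalone helper (hypothetical, not in the diff): the block level would expose the processor hook and fan it out to its attention layers.

```python
import torch.nn as nn


def set_block_attn_proc(block: nn.Module, attn_proc) -> None:
    """Forward a processor to a BasicTransformerBlock's attention layers so
    callers don't have to reach into block.attn1 / block.attn2 themselves."""
    block.attn1.attn_proc = attn_proc
    block.attn2.attn_proc = attn_proc
```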