attention refactor: the trilogy #3387

Merged
174 changes: 1 addition & 173 deletions src/diffusers/models/attention.py
@@ -11,189 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings


if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None


class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention.

Parameters:
channels (`int`): The number of channels in the input and output.
num_head_channels (`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""

# IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore

def __init__(
self,
channels: int,
num_head_channels: Optional[int] = None,
norm_num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.channels = channels

self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)

self.rescale_output_factor = rescale_output_factor
self.proj_attn = nn.Linear(channels, channels, bias=True)

self._use_memory_efficient_attention_xformers = False
self._use_2_0_attn = True
self._attention_op = None

def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
batch_size, seq_len, dim = tensor.shape
head_size = self.num_heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3)
if merge_head_and_batch:
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor

def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
head_size = self.num_heads

if unmerge_head_and_batch:
batch_head_size, seq_len, dim = tensor.shape
batch_size = batch_head_size // head_size

tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
else:
batch_size, _, seq_len, dim = tensor.shape

tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
return tensor

def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self._attention_op = attention_op
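# Hedged usage sketch (illustrative only, not part of this diff): on a CUDA machine
# with xformers installed, the memory-efficient path can be toggled per instance:
#
#     attn = AttentionBlock(channels=512, num_head_channels=64).to("cuda")
#     attn.set_use_memory_efficient_attention_xformers(True)
#
# Without xformers this raises ModuleNotFoundError; without a CUDA device, ValueError.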

def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape

# norm
hidden_states = self.group_norm(hidden_states)

hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)

scale = 1 / math.sqrt(self.channels / self.num_heads)

_use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn

query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)

if self._use_memory_efficient_attention_xformers:
# Memory efficient attention
hidden_states = xformers.ops.memory_efficient_attention(
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
)
hidden_states = hidden_states.to(query_proj.dtype)
elif use_torch_2_0_attn:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.to(query_proj.dtype)
else:
attention_scores = torch.baddbmm(
torch.empty(
query_proj.shape[0],
query_proj.shape[1],
key_proj.shape[1],
dtype=query_proj.dtype,
device=query_proj.device,
),
query_proj,
key_proj.transpose(-1, -2),
beta=0,
alpha=scale,
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
hidden_states = torch.bmm(attention_probs, value_proj)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)

# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)

hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
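# Hedged equivalence sketch (illustrative only, not part of this diff): the three
# branches of forward() above compute the same scaled dot-product attention. A minimal
# standalone check of the torch 2.0 path against the baddbmm fallback, assuming small
# random tensors and head_dim = channels // num_heads = 40 (so SDPA's default scale of
# 1/sqrt(head_dim) matches the explicit scale used by baddbmm):
import math
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 16, 40
scale = 1 / math.sqrt(head_dim)
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# torch 2.0 path: heads kept as a separate dimension
sdpa_out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# fallback path: heads merged into the batch dimension, then baddbmm + softmax + bmm
q2, k2, v2 = (t.reshape(batch * heads, seq_len, head_dim) for t in (q, k, v))
scores = torch.baddbmm(
    torch.empty(batch * heads, seq_len, seq_len, dtype=q2.dtype, device=q2.device),
    q2,
    k2.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
probs = scores.float().softmax(dim=-1).type(scores.dtype)
manual_out = torch.bmm(probs, v2).reshape(batch, heads, seq_len, head_dim)

assert torch.allclose(sdpa_out, manual_out, atol=1e-5)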


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""