
Commit 503ad9e

wip
1 parent c49e9ed commit 503ad9e


5 files changed: +208 -243 lines changed


Diff for: src/diffusers/models/attention.py

+1 -173
@@ -11,189 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
-from typing import Callable, Optional
+from typing import Optional

 import torch
 import torch.nn.functional as F
 from torch import nn

 from ..utils import maybe_allow_in_graph
-from ..utils.import_utils import is_xformers_available
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings


-if is_xformers_available():
-    import xformers
-    import xformers.ops
-else:
-    xformers = None
-
-
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
-    to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    Uses three q, k, v linear layers to compute attention.
-
-    Parameters:
-        channels (`int`): The number of channels in the input and output.
-        num_head_channels (`int`, *optional*):
-            The number of channels in each head. If None, then `num_heads` = 1.
-        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
-        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
-        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
-    """
-
-    # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
-
-    def __init__(
-        self,
-        channels: int,
-        num_head_channels: Optional[int] = None,
-        norm_num_groups: int = 32,
-        rescale_output_factor: float = 1.0,
-        eps: float = 1e-5,
-    ):
-        super().__init__()
-        self.channels = channels
-
-        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
-        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
-
-        # define q,k,v as linear layers
-        self.query = nn.Linear(channels, channels)
-        self.key = nn.Linear(channels, channels)
-        self.value = nn.Linear(channels, channels)
-
-        self.rescale_output_factor = rescale_output_factor
-        self.proj_attn = nn.Linear(channels, channels, bias=True)
-
-        self._use_memory_efficient_attention_xformers = False
-        self._use_2_0_attn = True
-        self._attention_op = None
-
-    def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3)
-        if merge_head_and_batch:
-            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True):
-        head_size = self.num_heads
-
-        if unmerge_head_and_batch:
-            batch_head_size, seq_len, dim = tensor.shape
-            batch_size = batch_head_size // head_size
-
-            tensor = tensor.reshape(batch_size, head_size, seq_len, dim)
-        else:
-            batch_size, _, seq_len, dim = tensor.shape
-
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size)
-        return tensor
-
-    def set_use_memory_efficient_attention_xformers(
-        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
-    ):
-        if use_memory_efficient_attention_xformers:
-            if not is_xformers_available():
-                raise ModuleNotFoundError(
-                    (
-                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                        " xformers"
-                    ),
-                    name="xformers",
-                )
-            elif not torch.cuda.is_available():
-                raise ValueError(
-                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
-                    " only available for GPU "
-                )
-            else:
-                try:
-                    # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
-                except Exception as e:
-                    raise e
-        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-        self._attention_op = attention_op
-
-    def forward(self, hidden_states):
-        residual = hidden_states
-        batch, channel, height, width = hidden_states.shape
-
-        # norm
-        hidden_states = self.group_norm(hidden_states)
-
-        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-
-        # proj to q, k, v
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(hidden_states)
-        value_proj = self.value(hidden_states)
-
-        scale = 1 / math.sqrt(self.channels / self.num_heads)
-
-        _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers
-        use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn
-
-        query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn)
-
-        if self._use_memory_efficient_attention_xformers:
-            # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale
-            )
-            hidden_states = hidden_states.to(query_proj.dtype)
-        elif use_torch_2_0_attn:
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            hidden_states = F.scaled_dot_product_attention(
-                query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False
-            )
-            hidden_states = hidden_states.to(query_proj.dtype)
-        else:
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_proj.shape[0],
-                    query_proj.shape[1],
-                    key_proj.shape[1],
-                    dtype=query_proj.dtype,
-                    device=query_proj.device,
-                ),
-                query_proj,
-                key_proj.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
-            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-            hidden_states = torch.bmm(attention_probs, value_proj)
-
-        # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn)
-
-        # compute next hidden_states
-        hidden_states = self.proj_attn(hidden_states)
-
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
-
-        # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
-        return hidden_states
-
-
 @maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
