
Commit e11a51a

[wip] attention refactor
1 parent 125d783 commit e11a51a

File tree: 9 files changed, +417 −92 lines

Diff for: scripts/convert_deprecated_attention_block.py (+116)

@@ -0,0 +1,116 @@
import argparse

from torch import nn

from diffusers import DiffusionPipeline
from diffusers.models.attention import AttentionBlock, assert_no_deprecated_attention_blocks
from diffusers.models.autoencoder_kl import AutoencoderKL
from diffusers.models.unet_2d import UNet2DModel
from diffusers.models.unet_2d_blocks import (
    AttnDownBlock2D,
    AttnDownEncoderBlock2D,
    AttnSkipDownBlock2D,
    AttnSkipUpBlock2D,
    AttnUpBlock2D,
    AttnUpDecoderBlock2D,
    UNetMidBlock2D,
)
from diffusers.models.vq_model import VQModel


# model classes whose config gains the `attention_block_type` flag
MODULES = [AutoencoderKL, VQModel, UNet2DModel]

UNET_BLOCKS = [
    UNetMidBlock2D,
    AttnDownBlock2D,
    AttnDownEncoderBlock2D,
    AttnSkipDownBlock2D,
    AttnUpBlock2D,
    AttnUpDecoderBlock2D,
    AttnSkipUpBlock2D,
]


unet_blocks_to_convert = []


# Patch the attention-bearing UNet block constructors so every instance created
# while the pipeline loads is recorded for conversion.
def patch_unet_block(unet_block_class):
    orig_constructor = unet_block_class.__init__

    def new_constructor(self, *args, **kwargs):
        orig_constructor(self, *args, **kwargs)
        unet_blocks_to_convert.append(self)

    def convert_attention_blocks(self):
        new_attentions = []

        for attention_block in self.attentions:
            if isinstance(attention_block, AttentionBlock):
                new_attention_block = attention_block.as_cross_attention()
            else:
                new_attention_block = attention_block

            new_attentions.append(new_attention_block)

        self.attentions = nn.ModuleList(new_attentions)

    unet_block_class.__init__ = new_constructor
    unet_block_class.convert_attention_blocks = convert_attention_blocks


for unet_block_class in UNET_BLOCKS:
    patch_unet_block(unet_block_class)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pipeline",
        default=None,
        type=str,
        required=True,
        help="Pipeline to convert the deprecated `AttentionBlock` to `CrossAttention`",
    )

    parser.add_argument(
        "--dump_path", default=None, type=str, required=True, help="Path to save the converted pipeline."
    )

    args = parser.parse_args()

    print(f"loading original pipeline {args.pipeline}")

    pipe = DiffusionPipeline.from_pretrained(args.pipeline)

    any_converted = False

    for attr_name in dir(pipe):
        attr = getattr(pipe, attr_name)

        for module in MODULES:
            if isinstance(attr, module):
                print(
                    f"converting `DiffusionPipeline.from_pretrained({args.pipeline}).{attr_name}.attention_block_type`"
                )
                attr.register_to_config(attention_block_type="CrossAttention")
                any_converted = True

    for unet_block in unet_blocks_to_convert:
        print(f"converting {unet_block.__class__}.attentions")
        unet_block.convert_attention_blocks()
        any_converted = True

    if not any_converted:
        print(f"`DiffusionPipeline.from_pretrained({args.pipeline})` did not have any deprecated attention blocks")
    else:
        print(f"Saving converted pipeline to {args.dump_path}")

        pipe.save_pretrained(args.dump_path)

        print("Checking converted pipeline has no deprecated attention blocks")

        with assert_no_deprecated_attention_blocks():
            pipe = DiffusionPipeline.from_pretrained(args.dump_path)

        print(f"Converted pipeline saved to {args.dump_path}")

Diff for: src/diffusers/models/attention.py (+63, −63)

@@ -11,15 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
+from contextlib import ContextDecorator
 from typing import Callable, Optional

 import torch
 import torch.nn.functional as F
 from torch import nn

+from ..utils import deprecate
 from ..utils.import_utils import is_xformers_available
-from .cross_attention import CrossAttention
+from .cross_attention import CrossAttention, SpatialAttnProcessor, XFormersSpatialAttnProcessor
 from .embeddings import CombinedTimestepLabelEmbeddings

@@ -57,6 +58,20 @@ def __init__(
         eps: float = 1e-5,
     ):
         super().__init__()
+
+        if _assert_no_deprecated_attention_blocks > 0:
+            raise AssertionError(
+                "Deprecated `AttentionBlock` created while `assert_no_deprecated_attention_blocks` context manager"
+                " active."
+            )
+
+        deprecation_message = (
+            "AttentionBlock has been deprecated and will be replaced with CrossAttention. TODO add upgrade"
+            " instructions"
+        )
+
+        deprecate("AttentionBlock", "1.0.0", deprecation_message, standard_warn=True)
+
         self.channels = channels

         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
@@ -74,20 +89,6 @@ def __init__(
         self._use_memory_efficient_attention_xformers = False
         self._attention_op = None

-    def reshape_heads_to_batch_dim(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
-
-    def reshape_batch_dim_to_heads(self, tensor):
-        batch_size, seq_len, dim = tensor.shape
-        head_size = self.num_heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
-
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
@@ -119,59 +120,43 @@ def set_use_memory_efficient_attention_xformers(
             self._attention_op = attention_op

     def forward(self, hidden_states):
-        residual = hidden_states
-        batch, channel, height, width = hidden_states.shape
-
-        # norm
-        hidden_states = self.group_norm(hidden_states)
-
-        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
-
-        # proj to q, k, v
-        query_proj = self.query(hidden_states)
-        key_proj = self.key(hidden_states)
-        value_proj = self.value(hidden_states)
-
-        scale = 1 / math.sqrt(self.channels / self.num_heads)
+        attn = self.as_cross_attention()
+        hidden_states = attn(hidden_states)

-        query_proj = self.reshape_heads_to_batch_dim(query_proj)
-        key_proj = self.reshape_heads_to_batch_dim(key_proj)
-        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        return hidden_states

-        if self._use_memory_efficient_attention_xformers:
-            # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(
-                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
-            )
-            hidden_states = hidden_states.to(query_proj.dtype)
+    def as_cross_attention(self):
+        if self._attention_op is None:
+            processor = SpatialAttnProcessor()
         else:
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_proj.shape[0],
-                    query_proj.shape[1],
-                    key_proj.shape[1],
-                    dtype=query_proj.dtype,
-                    device=query_proj.device,
-                ),
-                query_proj,
-                key_proj.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
-            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-            hidden_states = torch.bmm(attention_probs, value_proj)
+            processor = XFormersSpatialAttnProcessor(self._attention_op)

-        # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
-
-        # compute next hidden_states
-        hidden_states = self.proj_attn(hidden_states)
+        if self.num_head_size is None:
+            # When `self.num_head_size` is None, there is a single attention head
+            # of all the channels
+            dim_head = self.channels
+        else:
+            dim_head = self.num_head_size
+
+        attn = CrossAttention(
+            self.channels,
+            heads=self.num_heads,
+            dim_head=dim_head,
+            bias=True,
+            upcast_softmax=True,
+            norm_num_groups=self.group_norm.num_groups,
+            processor=processor,
+            eps=self.group_norm.eps,
+            rescale_output_factor=self.rescale_output_factor,
+        )

-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        attn.group_norm = self.group_norm
+        attn.to_q = self.query
+        attn.to_k = self.key
+        attn.to_v = self.value
+        attn.to_out[0] = self.proj_attn

-        # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
-        return hidden_states
+        return attn


 class BasicTransformerBlock(nn.Module):
@@ -480,3 +465,18 @@ def forward(self, x, timestep, class_labels, hidden_dtype=None):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+# tracks the number of `assert_no_deprecated_attention_blocks` decorators
+_assert_no_deprecated_attention_blocks = 0
+
+
+class assert_no_deprecated_attention_blocks(ContextDecorator):
+    def __enter__(self):
+        global _assert_no_deprecated_attention_blocks
+        _assert_no_deprecated_attention_blocks += 1
+        return self
+
+    def __exit__(self, *args):
+        global _assert_no_deprecated_attention_blocks
+        _assert_no_deprecated_attention_blocks -= 1
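With this change, `AttentionBlock.forward` builds an equivalent `CrossAttention` module via `as_cross_attention()` (reusing its own group norm and query/key/value/output projections) and calls it, while constructing an `AttentionBlock` inside the new `assert_no_deprecated_attention_blocks` guard raises. A small sketch of the intended behaviour, assuming only the signature shown above (the channel counts are illustrative):

# Sketch only: constructor arguments are illustrative; behaviour follows the diff above.
import torch
from diffusers.models.attention import AttentionBlock, assert_no_deprecated_attention_blocks

block = AttentionBlock(channels=32, num_head_channels=8)  # emits a deprecation warning

# forward() now builds an equivalent CrossAttention (sharing this block's weights) and calls it
sample = torch.randn(1, 32, 16, 16)
out = block(sample)
print(out.shape)  # expected torch.Size([1, 32, 16, 16])

# the standalone converted module can also be extracted, as the conversion script does
converted = block.as_cross_attention()

# inside the guard, constructing the deprecated block raises
with assert_no_deprecated_attention_blocks():
    try:
        AttentionBlock(channels=32)
    except AssertionError as err:
        print(err)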

Diff for: src/diffusers/models/autoencoder_kl.py (+3)

@@ -79,6 +79,7 @@ def __init__(
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        attention_block_type: str = "AttentionBlock",
     ):
         super().__init__()

@@ -92,6 +93,7 @@ def __init__(
             act_fn=act_fn,
             norm_num_groups=norm_num_groups,
             double_z=True,
+            attention_block_type=attention_block_type,
         )

         # pass init params to Decoder
@@ -103,6 +105,7 @@ def __init__(
             layers_per_block=layers_per_block,
             norm_num_groups=norm_num_groups,
             act_fn=act_fn,
+            attention_block_type=attention_block_type,
         )

         self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
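The VAE change is pure config plumbing: `attention_block_type` defaults to "AttentionBlock" and is forwarded to the Encoder and Decoder, whose handling of the flag lives in other files of this commit. A hedged sketch of opting a freshly constructed model into the new path, assuming those companion changes are in place:

# Sketch only: relies on the Encoder/Decoder changes elsewhere in this commit
# actually honoring attention_block_type.
from diffusers import AutoencoderKL

vae = AutoencoderKL(attention_block_type="CrossAttention")
print(vae.config.attention_block_type)  # "CrossAttention"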

0 commit comments