Skip to content

Commit 589bc20

Browse files
committed
Tidy _prepare_attention_processors(...) logic.
1 parent 1430b3a commit 589bc20

File tree

1 file changed

+2
-15
lines changed

1 file changed

+2
-15
lines changed

invokeai/backend/ip_adapter/unet_patcher.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,10 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[
1313
1414
Note that the `unet` param is only used to determine attention block dimensions and naming.
1515
"""
16-
# TODO(ryand): This logic can be simplified.
17-
1816
# Construct a dict of attention processors based on the UNet's architecture.
1917
attn_procs = {}
2018
for idx, name in enumerate(unet.attn_processors.keys()):
21-
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
22-
if name.startswith("mid_block"):
23-
hidden_size = unet.config.block_out_channels[-1]
24-
elif name.startswith("up_blocks"):
25-
block_id = int(name[len("up_blocks.")])
26-
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
27-
elif name.startswith("down_blocks"):
28-
block_id = int(name[len("down_blocks.")])
29-
hidden_size = unet.config.block_out_channels[block_id]
30-
31-
if cross_attention_dim is None:
19+
if name.endswith("attn1.processor"):
3220
attn_procs[name] = AttnProcessor2_0()
3321
else:
3422
# Collect the weights from each IP Adapter for the idx'th attention processor.
@@ -43,8 +31,7 @@ def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPA
4331
"""A context manager that patches `unet` with IP-Adapter attention processors.
4432
4533
Yields:
46-
Scales: The Scales object, which can be used to dynamically alter the scales of the
47-
IP-Adapters.
34+
Scales: The Scales object, which can be used to dynamically alter the scales of the IP-Adapters.
4835
"""
4936
scales = Scales([1.0] * len(ip_adapters))
5037

0 commit comments

Comments
 (0)