@@ -13,22 +13,10 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[
13
13
14
14
Note that the `unet` param is only used to determine attention block dimensions and naming.
15
15
"""
16
- # TODO(ryand): This logic can be simplified.
17
-
18
16
# Construct a dict of attention processors based on the UNet's architecture.
19
17
attn_procs = {}
20
18
for idx , name in enumerate (unet .attn_processors .keys ()):
21
- cross_attention_dim = None if name .endswith ("attn1.processor" ) else unet .config .cross_attention_dim
22
- if name .startswith ("mid_block" ):
23
- hidden_size = unet .config .block_out_channels [- 1 ]
24
- elif name .startswith ("up_blocks" ):
25
- block_id = int (name [len ("up_blocks." )])
26
- hidden_size = list (reversed (unet .config .block_out_channels ))[block_id ]
27
- elif name .startswith ("down_blocks" ):
28
- block_id = int (name [len ("down_blocks." )])
29
- hidden_size = unet .config .block_out_channels [block_id ]
30
-
31
- if cross_attention_dim is None :
19
+ if name .endswith ("attn1.processor" ):
32
20
attn_procs [name ] = AttnProcessor2_0 ()
33
21
else :
34
22
# Collect the weights from each IP Adapter for the idx'th attention processor.
@@ -43,8 +31,7 @@ def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPA
43
31
"""A context manager that patches `unet` with IP-Adapter attention processors.
44
32
45
33
Yields:
46
- Scales: The Scales object, which can be used to dynamically alter the scales of the
47
- IP-Adapters.
34
+ Scales: The Scales object, which can be used to dynamically alter the scales of the IP-Adapters.
48
35
"""
49
36
scales = Scales ([1.0 ] * len (ip_adapters ))
50
37
0 commit comments