
Commit 690ebe1

Fix handling of scales with multiple IP-Adapters.
1 parent c899925 commit 690ebe1

5 files changed: +54 -25 lines changed

invokeai/backend/ip_adapter/attention_processor.py (+11 -5)

@@ -9,6 +9,7 @@
 from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+from invokeai.backend.ip_adapter.scales import Scales
 
 
 # Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
@@ -47,13 +48,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
             the weight scale of image prompt.
     """
 
-    def __init__(self, weights: list[IPAttentionProcessorWeights]):
+    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: Scales):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
-        self.weights = weights
+        assert len(weights) == len(scales)
+
+        self._weights = weights
+        self._scales = scales
 
     def __call__(
         self,
@@ -119,9 +123,11 @@ def __call__(
             # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
             # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
             assert ip_adapter_image_prompt_embeds is not None
-            assert len(ip_adapter_image_prompt_embeds) == len(self.weights)
+            assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
 
-            for ipa_embed, ipa_weights in zip(ip_adapter_image_prompt_embeds, self.weights):
+            for ipa_embed, ipa_weights, scale in zip(
+                ip_adapter_image_prompt_embeds, self._weights, self._scales.scales
+            ):
                 # The batch dimensions should match.
                 assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
                 # The channel dimensions should match.
@@ -144,7 +150,7 @@ def __call__(
                 ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                 ip_hidden_states = ip_hidden_states.to(query.dtype)
 
-                hidden_states = hidden_states + ipa_weights.scale * ip_hidden_states
+                hidden_states = hidden_states + scale * ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
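
The key behavioral change in this file: each IP-Adapter's attention output is now weighted by its own entry from the shared Scales object, read at call time, instead of by a `scale` attribute stored on the weights. A minimal sketch of that summation, with illustrative tensor shapes and scale values (not taken from the commit):

import torch

# Output of the regular text cross-attention: (batch, seq, channels).
hidden_states = torch.randn(1, 4096, 320)
# One attention output per IP-Adapter, computed from its image-prompt embeds.
ip_hidden_states = [torch.randn(1, 4096, 320), torch.randn(1, 4096, 320)]
# One scale per adapter, as read from self._scales.scales inside __call__.
scales = [1.0, 0.5]

for ip_h, scale in zip(ip_hidden_states, scales):
    # Each adapter contributes independently, weighted by its own scale.
    hidden_states = hidden_states + scale * ip_h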

invokeai/backend/ip_adapter/ip_attention_weights.py (+1 -7)

@@ -8,9 +8,8 @@ class IPAttentionProcessorWeights(torch.nn.Module):
     method.
     """
 
-    def __init__(self, in_dim: int, out_dim: int, scale: float = 1.0):
+    def __init__(self, in_dim: int, out_dim: int):
         super().__init__()
-        self.scale = scale
         self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
         self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
 
@@ -26,11 +25,6 @@ def __init__(self, weights: torch.nn.ModuleDict):
         super().__init__()
         self._weights = weights
 
-    def set_scale(self, scale: float):
-        """Set the scale (a.k.a. 'weight') for all of the `IPAttentionProcessorWeights` in this collection."""
-        for w in self._weights.values():
-            w.scale = scale
-
     def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights:
         """Get the `IPAttentionProcessorWeights` for the idx'th attention processor."""
         # Cast to int first, because we expect the key to represent an int. Then cast back to str, because
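
With `set_scale` removed, IPAttentionProcessorWeights carries only the two IP projections; scale state lives outside the module entirely. A quick sketch of the slimmed-down constructor, assuming the invokeai package at this commit is importable (the dims here are illustrative):

import torch
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights

# The constructor no longer accepts a scale; the module owns only to_k_ip and to_v_ip.
w = IPAttentionProcessorWeights(in_dim=768, out_dim=320)
k = w.to_k_ip(torch.randn(1, 4, 768))  # image-prompt keys, shape (1, 4, 320)
v = w.to_v_ip(torch.randn(1, 4, 768))  # image-prompt values, shape (1, 4, 320)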

invokeai/backend/ip_adapter/scales.py (new file, +19)

@@ -0,0 +1,19 @@
+class Scales:
+    """The IP-Adapter scales for a patched UNet. This object can be used to dynamically change the scales for a patched
+    UNet.
+    """
+
+    def __init__(self, scales: list[float]):
+        self._scales = scales
+
+    @property
+    def scales(self):
+        return self._scales
+
+    @scales.setter
+    def scales(self, scales: list[float]):
+        assert len(scales) == len(self._scales)
+        self._scales = scales
+
+    def __len__(self):
+        return len(self._scales)
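
Scales is a thin, length-checked wrapper around a list, so the patcher and every attention processor can hold the same instance and observe updates. A short usage sketch:

from invokeai.backend.ip_adapter.scales import Scales

scales = Scales([1.0, 1.0])  # one entry per IP-Adapter
shared_view = scales         # processors keep a reference, not a copy

scales.scales = [0.8, 0.0]   # the setter enforces the same length
print(shared_view.scales)    # -> [0.8, 0.0]; the update is visible everywhere
# scales.scales = [0.5]      # would raise: the setter asserts len(new) == len(old)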

invokeai/backend/ip_adapter/unet_patcher.py (+13 -5)

@@ -4,9 +4,10 @@
 
 from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.ip_adapter.scales import Scales
 
 
-def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
+def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter], scales: Scales):
     """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
     weights into them.
 
@@ -32,15 +33,22 @@ def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[
         else:
             # Collect the weights from each IP Adapter for the idx'th attention processor.
             attn_procs[name] = IPAttnProcessor2_0(
-                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters]
+                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters], scales
             )
     return attn_procs
 
 
 @contextmanager
 def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
-    """A context manager that patches `unet` with IP-Adapter attention processors."""
-    attn_procs = _prepare_attention_processors(unet, ip_adapters)
+    """A context manager that patches `unet` with IP-Adapter attention processors.
+
+    Yields:
+        Scales: The Scales object, which can be used to dynamically alter the scales of the
+            IP-Adapters.
+    """
+    scales = Scales([1.0] * len(ip_adapters))
+
+    attn_procs = _prepare_attention_processors(unet, ip_adapters, scales)
 
     orig_attn_processors = unet.attn_processors
 
@@ -49,6 +57,6 @@ def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPA
         # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
         # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
         unet.set_attn_processor(attn_procs)
-        yield None
+        yield scales
     finally:
        unet.set_attn_processor(orig_attn_processors)
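
apply_ip_adapter_attention now yields the Scales object it created, so callers can adjust per-adapter scales while the UNet is patched. A usage sketch, assuming `unet` and `ip_adapters` (here, two adapters) are already loaded:

from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention

with apply_ip_adapter_attention(unet, ip_adapters) as scales:
    scales.scales = [1.0, 0.3]  # re-weight the two adapters mid-run
    # ... run denoising steps; the patched processors read the updated scales ...
# On exit, the original attention processors are restored.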

invokeai/backend/stable_diffusion/diffusers_pipeline.py (+10 -8)

@@ -24,7 +24,7 @@
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
+from invokeai.backend.ip_adapter.unet_patcher import Scales, apply_ip_adapter_attention
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
 
 from ..util import auto_detect_slice_size, normalize_device
@@ -426,7 +426,7 @@ def generate_latents_from_embeddings(
             return latents, attention_map_saver
 
         if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
-            attn_ctx = self.invokeai_diffuser.custom_attention_context(
+            attn_ctx_mgr = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=conditioning_data.extra,
                 step_count=len(self.scheduler.timesteps),
@@ -435,14 +435,14 @@ def generate_latents_from_embeddings(
         elif ip_adapter_data is not None:
            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
            # As it is now, the IP-Adapter will silently be skipped.
-            attn_ctx = apply_ip_adapter_attention(
+            attn_ctx_mgr = apply_ip_adapter_attention(
                unet=self.invokeai_diffuser.model, ip_adapters=[ipa.ip_adapter_model for ipa in ip_adapter_data]
            )
            self.use_ip_adapter = True
         else:
-            attn_ctx = nullcontext()
+            attn_ctx_mgr = nullcontext()
 
-        with attn_ctx:
+        with attn_ctx_mgr as attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
@@ -467,6 +467,7 @@ def generate_latents_from_embeddings(
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
+                    attn_ctx=attn_ctx,
                 )
                 latents = step_output.prev_sample
 
@@ -514,6 +515,7 @@ def step(
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
+        attn_ctx: Optional[Scales] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -526,7 +528,7 @@ def step(
 
         # handle IP-Adapter
         if self.use_ip_adapter and ip_adapter_data is not None:  # somewhat redundant but logic is clearer
-            for single_ip_adapter_data in ip_adapter_data:
+            for i, single_ip_adapter_data in enumerate(ip_adapter_data):
                 first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
                 last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
                 weight = (
@@ -536,10 +538,10 @@ def step(
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    single_ip_adapter_data.ip_adapter_model.attn_weights.set_scale(weight)
+                    attn_ctx.scales[i] = weight
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    single_ip_adapter_data.ip_adapter_model.attn_weights.set_scale(0.0)
+                    attn_ctx.scales[i] = 0.0
 
         # Handle ControlNet(s) and T2I-Adapter(s)
         down_block_additional_residuals = None
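
step() now writes each adapter's current weight into attn_ctx.scales on every step, gated by that adapter's begin/end step range. A standalone sketch of the scheduling logic; the diff elides how `weight` is computed, so a plain float `adapter.weight` is assumed here:

import math

from invokeai.backend.ip_adapter.scales import Scales

def apply_step_scales(scales: Scales, adapters, step_index: int, total_step_count: int):
    """Enable each adapter only while the current step is inside its range."""
    for i, adapter in enumerate(adapters):
        first = math.floor(adapter.begin_step_percent * total_step_count)
        last = math.ceil(adapter.end_step_percent * total_step_count)
        if first <= step_index <= last:
            scales.scales[i] = adapter.weight  # active: apply the adapter's weight
        else:
            scales.scales[i] = 0.0             # inactive: zero out its contribution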
