Commit 6b71b39

Refactor multi-IP-Adapter to clean up the interface around changing scales.

1 parent: 591e717
6 files changed: +63 −81 lines
invokeai/app/invocations/latent.py (−1)

```diff
@@ -10,7 +10,6 @@
 import torchvision.transforms as T
 from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models import UNet2DConditionModel
 from diffusers.models.adapter import FullAdapterXL, T2IAdapter
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
```

invokeai/backend/ip_adapter/attention_processor.py (+2 −5)

```diff
@@ -9,7 +9,6 @@
 from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0

 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
-from invokeai.backend.ip_adapter.scales import Scales


 # Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
@@ -48,7 +47,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
         the weight scale of image prompt.
     """

-    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: Scales):
+    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
         super().__init__()

         if not hasattr(F, "scaled_dot_product_attention"):
@@ -125,9 +124,7 @@ def __call__(
         assert ip_adapter_image_prompt_embeds is not None
         assert len(ip_adapter_image_prompt_embeds) == len(self._weights)

-        for ipa_embed, ipa_weights, scale in zip(
-            ip_adapter_image_prompt_embeds, self._weights, self._scales.scales
-        ):
+        for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
             # The batch dimensions should match.
             assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
             # The channel dimensions should match.
```
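For context, each `scale` in this loop gates one adapter's image-prompt attention before it is merged into the text-conditioned output. A minimal sketch of the pattern, not this commit's code: the `merge_ip_attention` helper is hypothetical, and head reshaping/masking are elided; only `to_k_ip`/`to_v_ip` and the zip over weights and scales come from the file above.

```python
import torch
import torch.nn.functional as F

# Sketch: how a per-adapter scale typically gates image-prompt attention.
# `query` is the attention query and `hidden_states` the text cross-attention
# result; `weights` are IPAttentionProcessorWeights with to_k_ip/to_v_ip
# projections. Head reshaping is elided for brevity.
def merge_ip_attention(
    hidden_states: torch.Tensor,
    query: torch.Tensor,
    ip_embeds: list[torch.Tensor],
    weights: list,
    scales: list[float],
) -> torch.Tensor:
    for ipa_embed, ipa_weights, scale in zip(ip_embeds, weights, scales):
        ip_key = ipa_weights.to_k_ip(ipa_embed)
        ip_value = ipa_weights.to_v_ip(ipa_embed)
        ip_hidden_states = F.scaled_dot_product_attention(query, ip_key, ip_value)
        # scale == 0.0 disables this adapter for the current step; 1.0 applies it fully.
        hidden_states = hidden_states + scale * ip_hidden_states
    return hidden_states
```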

invokeai/backend/ip_adapter/scales.py (−19)

This file was deleted.
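The deleted file's contents are not shown in this diff, but its old call sites (`Scales([1.0] * len(ip_adapters))` and `self._scales.scales`) imply it was a thin wrapper around a list of floats. A hypothetical reconstruction, inferred from those call sites; the real file may have differed in detail:

```python
# Hypothetical reconstruction of the deleted Scales wrapper (not from the diff).
class Scales:
    """Wraps the per-IP-Adapter scales list so that it can be shared by reference
    and updated after the attention processors have been created."""

    def __init__(self, scales: list[float]):
        self._scales = scales

    @property
    def scales(self) -> list[float]:
        return self._scales

    @scales.setter
    def scales(self, scales: list[float]):
        # Replacing the list wholesale must preserve the adapter count.
        assert len(scales) == len(self._scales)
        self._scales = scales
```

A plain Python list is already mutable and shared by reference, so the wrapper bought little; this commit replaces it with a bare `list[float]` owned by `UNetPatcher` and mutated through `set_scale`.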

invokeai/backend/ip_adapter/unet_patcher.py (+47 −43)

```diff
@@ -4,46 +4,50 @@

 from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.scales import Scales
-
-
-def _prepare_attention_processors(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter], scales: Scales):
-    """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
-    weights into them.
-
-    Note that the `unet` param is only used to determine attention block dimensions and naming.
-    """
-    # Construct a dict of attention processors based on the UNet's architecture.
-    attn_procs = {}
-    for idx, name in enumerate(unet.attn_processors.keys()):
-        if name.endswith("attn1.processor"):
-            attn_procs[name] = AttnProcessor2_0()
-        else:
-            # Collect the weights from each IP Adapter for the idx'th attention processor.
-            attn_procs[name] = IPAttnProcessor2_0(
-                [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in ip_adapters], scales
-            )
-    return attn_procs
-
-
-@contextmanager
-def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_adapters: list[IPAdapter]):
-    """A context manager that patches `unet` with IP-Adapter attention processors.
-
-    Yields:
-        Scales: The Scales object, which can be used to dynamically alter the scales of the IP-Adapters.
-    """
-    scales = Scales([1.0] * len(ip_adapters))
-
-    attn_procs = _prepare_attention_processors(unet, ip_adapters, scales)
-
-    orig_attn_processors = unet.attn_processors
-
-    try:
-        # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
-        # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
-        # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
-        unet.set_attn_processor(attn_procs)
-        yield scales
-    finally:
-        unet.set_attn_processor(orig_attn_processors)
+
+
+class UNetPatcher:
+    """A class that contains multiple IP-Adapters and can apply them to a UNet."""
+
+    def __init__(self, ip_adapters: list[IPAdapter]):
+        self._ip_adapters = ip_adapters
+        self._scales = [1.0] * len(self._ip_adapters)
+
+    def set_scale(self, idx: int, value: float):
+        self._scales[idx] = value
+
+    def _prepare_attention_processors(self, unet: UNet2DConditionModel):
+        """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
+        weights into them.
+
+        Note that the `unet` param is only used to determine attention block dimensions and naming.
+        """
+        # Construct a dict of attention processors based on the UNet's architecture.
+        attn_procs = {}
+        for idx, name in enumerate(unet.attn_processors.keys()):
+            if name.endswith("attn1.processor"):
+                attn_procs[name] = AttnProcessor2_0()
+            else:
+                # Collect the weights from each IP Adapter for the idx'th attention processor.
+                attn_procs[name] = IPAttnProcessor2_0(
+                    [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
+                    self._scales,
+                )
+        return attn_procs
+
+    @contextmanager
+    def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
+        """A context manager that patches `unet` with IP-Adapter attention processors."""
+
+        attn_procs = self._prepare_attention_processors(unet)
+
+        orig_attn_processors = unet.attn_processors
+
+        try:
+            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
+            # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
+            # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
+            unet.set_attn_processor(attn_procs)
+            yield None
+        finally:
+            unet.set_attn_processor(orig_attn_processors)
```
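The key design point is aliasing: each `IPAttnProcessor2_0` holds a reference to the same `self._scales` list that `set_scale` mutates in place, so scale changes take effect on the next forward pass without re-patching the UNet. A minimal usage sketch, assuming `unet`, `ip_adapters`, `latents`, `timestep`, `text_embeds`, and the image-prompt embeddings are loaded or computed elsewhere:

```python
# Minimal usage sketch for UNetPatcher; all input tensors and models are
# assumed to come from elsewhere (e.g. InvokeAI's model manager).
patcher = UNetPatcher(ip_adapters)
with patcher.apply_ip_adapter_attention(unet):
    # Inside the context, the patched processors consult the shared scales
    # list on every call, so this takes effect on the next forward pass.
    patcher.set_scale(0, 0.75)
    noise_pred = unet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        cross_attention_kwargs={"ip_adapter_image_prompt_embeds": image_prompt_embeds},
    ).sample
# On exit, the original attention processors are restored.
```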

invokeai/backend/stable_diffusion/diffusers_pipeline.py (+11 −11)

```diff
@@ -24,7 +24,7 @@

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import Scales, apply_ip_adapter_attention
+from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData

 from ..util import auto_detect_slice_size, normalize_device
@@ -425,8 +425,9 @@ def generate_latents_from_embeddings(
         if timesteps.shape[0] == 0:
             return latents, attention_map_saver

+        ip_adapter_unet_patcher = None
         if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
-            attn_ctx_mgr = self.invokeai_diffuser.custom_attention_context(
+            attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=conditioning_data.extra,
                 step_count=len(self.scheduler.timesteps),
@@ -435,14 +436,13 @@ def generate_latents_from_embeddings(
         elif ip_adapter_data is not None:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
-            attn_ctx_mgr = apply_ip_adapter_attention(
-                unet=self.invokeai_diffuser.model, ip_adapters=[ipa.ip_adapter_model for ipa in ip_adapter_data]
-            )
+            ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
+            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
             self.use_ip_adapter = True
         else:
-            attn_ctx_mgr = nullcontext()
+            attn_ctx = nullcontext()

-        with attn_ctx_mgr as attn_ctx:
+        with attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
@@ -467,7 +467,7 @@ def generate_latents_from_embeddings(
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
-                    attn_ctx=attn_ctx,
+                    ip_adapter_unet_patcher=ip_adapter_unet_patcher,
                 )
                 latents = step_output.prev_sample

@@ -515,7 +515,7 @@ def step(
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        attn_ctx: Optional[Scales] = None,
+        ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -538,10 +538,10 @@ def step(
             )
             if step_index >= first_adapter_step and step_index <= last_adapter_step:
                 # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                attn_ctx.scales[i] = weight
+                ip_adapter_unet_patcher.set_scale(i, weight)
             else:
                 # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                attn_ctx.scales[i] = 0.0
+                ip_adapter_unet_patcher.set_scale(i, 0.0)

         # Handle ControlNet(s) and T2I-Adapter(s)
         down_block_additional_residuals = None
```
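The per-step scheduling shown in the last hunk can be read as a standalone helper. A sketch under stated assumptions: the diff does not show how `first_adapter_step`/`last_adapter_step` are derived, so the `begin_step_percent`/`end_step_percent` fields, `weight` field, and `total_step_count` argument are assumptions; only the range check and `set_scale` calls come from the diff.

```python
import math

# Sketch (hypothetical helper): keep each IP-Adapter's scale in sync with the
# current denoising step. Adapters are active only inside their step range.
def sync_ip_adapter_scales(patcher, ip_adapter_data, step_index: int, total_step_count: int):
    for i, ipa in enumerate(ip_adapter_data):
        # Assumed fields: begin/end expressed as fractions of the total steps.
        first_adapter_step = math.floor(ipa.begin_step_percent * total_step_count)
        last_adapter_step = math.ceil(ipa.end_step_percent * total_step_count)
        if first_adapter_step <= step_index <= last_adapter_step:
            # Within the active range: apply this adapter's configured weight.
            patcher.set_scale(i, ipa.weight)
        else:
            # Outside the range: zero the scale so the adapter has no effect.
            patcher.set_scale(i, 0.0)
```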

tests/backend/ip_adapter/test_ip_adapter.py (+3 −2)

```diff
@@ -1,7 +1,7 @@
 import pytest
 import torch

-from invokeai.backend.ip_adapter.unet_patcher import apply_ip_adapter_attention
+from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
 from invokeai.backend.util.test_utils import install_and_load_model, slow

@@ -66,7 +66,8 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     unet.to(torch_device, dtype=torch.float32)

     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
-    with apply_ip_adapter_attention(unet, [ip_adapter]):
+    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample

     assert output.shape == dummy_unet_input["sample"].shape
```
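The test exercises the patched forward pass but does not assert the restore-on-exit behavior of the context manager. A hypothetical extra check, not part of this commit, might look like:

```python
# Hypothetical extension (not in this commit): verify that exiting the context
# restores the UNet's original attention processor objects.
orig_procs = dict(unet.attn_processors)
with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
    pass  # processors are temporarily replaced inside the context
restored_procs = unet.attn_processors
assert all(restored_procs[name] is orig_procs[name] for name in orig_procs)
```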
