Support multiple IP-Adapters (workflow editor only) #4818

Merged: 12 commits, Oct 7, 2023
81 changes: 46 additions & 35 deletions invokeai/app/invocations/latent.py
@@ -10,7 +10,6 @@
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
from diffusers.models.attention_processor import (
AttnProcessor2_0,
@@ -199,7 +198,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.2.0",
version="1.3.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@@ -226,7 +225,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[IPAdapterField] = InputField(
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
)
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
@@ -410,52 +409,65 @@ def prep_control_data(
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapter: Optional[IPAdapterField],
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
conditioning_data: ConditioningData,
unet: UNet2DConditionModel,
exit_stack: ExitStack,
) -> Optional[IPAdapterData]:
) -> Optional[list[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
"""
if ip_adapter is None:
return None

image_encoder_model_info = context.services.model_manager.get_model(
model_name=ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=ip_adapter.image_encoder_model.base_model,
context=context,
)
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if not isinstance(ip_adapter, list):
ip_adapter = [ip_adapter]

if len(ip_adapter) == 0:
return None

ip_adapter_data_list = []
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=single_ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=single_ip_adapter.ip_adapter_model.base_model,
context=context,
)
)

ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=ip_adapter.ip_adapter_model.base_model,
image_encoder_model_info = context.services.model_manager.get_model(
model_name=single_ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=single_ip_adapter.image_encoder_model.base_model,
context=context,
)
)

input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)

# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning.append(
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
)

ip_adapter_data_list.append(
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
)
)

return IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=ip_adapter.weight,
begin_step_percent=ip_adapter.begin_step_percent,
end_step_percent=ip_adapter.end_step_percent,
)
return ip_adapter_data_list

def run_t2i_adapters(
self,
@@ -677,7 +689,6 @@ def _lora_loader():
context=context,
ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
unet=unet,
exit_stack=exit_stack,
)

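
The core of this change to `prep_ip_adapter_data` is that the node's `ip_adapter` field may now arrive as `None`, a single `IPAdapterField`, or a list of them, and the method normalizes it to a list before looping over each adapter to load its models and collect per-adapter conditioning. A minimal standalone sketch of that normalization pattern (the `IPAdapterSettings` dataclass and its fields below are illustrative placeholders, not the actual InvokeAI types):

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class IPAdapterSettings:
    # Illustrative stand-in for the per-adapter settings carried by IPAdapterField.
    model_name: str
    weight: float = 1.0
    begin_step_percent: float = 0.0
    end_step_percent: float = 1.0


def normalize_ip_adapters(
    ip_adapter: Optional[Union[IPAdapterSettings, list[IPAdapterSettings]]],
) -> list[IPAdapterSettings]:
    """Accept None, a single adapter, or a list of adapters; always return a (possibly empty) list."""
    if ip_adapter is None:
        return []
    if not isinstance(ip_adapter, list):
        return [ip_adapter]
    return ip_adapter


# Downstream code can then iterate without caring how many adapters the workflow wired in.
single = IPAdapterSettings(model_name="ip_adapter_sd15")
many = [single, IPAdapterSettings(model_name="ip_adapter_plus_sd15", weight=0.5)]
assert normalize_ip_adapters(None) == []
assert normalize_ip_adapters(single) == [single]
assert normalize_ip_adapters(many) == many
```

Keeping the single-value form in the `Union` presumably lets existing workflows that connect one IP-Adapter node keep working unchanged.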
61 changes: 32 additions & 29 deletions invokeai/backend/ip_adapter/attention_processor.py
@@ -8,6 +8,8 @@
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0

from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights


# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
@@ -45,18 +47,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
the weight scale of image prompt.
"""

def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
super().__init__()

if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
assert len(weights) == len(scales)

self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self._weights = weights
self._scales = scales

def __call__(
self,
@@ -67,16 +67,6 @@ def __call__(
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds

residual = hidden_states

if attn.spatial_norm is not None:
@@ -128,23 +118,36 @@ def __call__(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)

for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]

ip_hidden_states = ipa_embed

ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)

hidden_states = hidden_states + self.scale * ip_hidden_states
hidden_states = hidden_states + scale * ip_hidden_states

# linear proj
hidden_states = attn.to_out[0](hidden_states)
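
The rewritten `IPAttnProcessor2_0` replaces the single pair of `to_k_ip`/`to_v_ip` layers and one `self.scale` with per-adapter `IPAttentionProcessorWeights` and per-adapter scales, so the single IP-Adapter term becomes a loop that adds one scaled attention contribution per adapter on top of the text cross-attention output. A simplified sketch of just that accumulation step, assuming each adapter's key/value projections have already been applied (the function name and shapes are illustrative, not the actual processor API):

```python
import torch
import torch.nn.functional as F


def add_ip_adapter_attention(
    query: torch.Tensor,          # (batch, heads, q_len, head_dim): query from the text cross-attention
    hidden_states: torch.Tensor,  # (batch, q_len, heads * head_dim): text cross-attention output
    ip_keys: list[torch.Tensor],  # one (batch, heads, ip_len, head_dim) tensor per adapter
    ip_values: list[torch.Tensor],
    scales: list[float],          # one weight per adapter
) -> torch.Tensor:
    """Add one scaled attention contribution per IP-Adapter to the cross-attention output."""
    batch_size, heads, _, head_dim = query.shape
    for ip_key, ip_value, scale in zip(ip_keys, ip_values, scales):
        ip_out = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        # (batch, heads, q_len, head_dim) -> (batch, q_len, heads * head_dim)
        ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(query.dtype)
        hidden_states = hidden_states + scale * ip_out
    return hidden_states


# Example with two adapters at different weights.
b, h, q_len, ip_len, d = 2, 8, 64, 4, 40
out = add_ip_adapter_attention(
    query=torch.randn(b, h, q_len, d),
    hidden_states=torch.randn(b, q_len, h * d),
    ip_keys=[torch.randn(b, h, ip_len, d) for _ in range(2)],
    ip_values=[torch.randn(b, h, ip_len, d) for _ in range(2)],
    scales=[1.0, 0.5],
)
print(out.shape)  # torch.Size([2, 64, 320])
```

Because the per-adapter contributions are simply summed, each `scale` plays the same role the single `self.scale` did before; `F.scaled_dot_product_attention` still requires PyTorch 2.0, which the class already checks for.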
104 changes: 8 additions & 96 deletions invokeai/backend/ip_adapter/ip_adapter.py
@@ -1,17 +1,15 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed

from contextlib import contextmanager
from typing import Optional, Union

import torch
from diffusers.models import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_management.models.base import calc_model_size_by_data

from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from .resampler import Resampler


@@ -61,7 +59,7 @@ class IPAdapter:

def __init__(
self,
state_dict: dict[torch.Tensor],
state_dict: dict[str, torch.Tensor],
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
@@ -73,112 +71,26 @@ def __init__(

self._clip_image_processor = CLIPImageProcessor()

self._state_dict = state_dict
self._image_proj_model = self._init_image_proj_model(state_dict["image_proj"])

self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])

# The _attn_processors will be initialized later when we have access to the UNet.
self._attn_processors = None
self.attn_weights = IPAttentionWeights.from_state_dict(state_dict["ip_adapter"]).to(
self.device, dtype=self.dtype
)

def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
if dtype is not None:
self.dtype = dtype

self._image_proj_model.to(device=self.device, dtype=self.dtype)
if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
self.attn_weights.to(device=self.device, dtype=self.dtype)

def calc_size(self):
if self._state_dict is not None:
image_proj_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["image_proj"].values()]
)
ip_adapter_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["ip_adapter"].values()]
)
return image_proj_size + ip_adapter_size
else:
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(
torch.nn.ModuleList(self._attn_processors.values())
)
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)

def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
attention weights into them.

Note that the `unet` param is only used to determine attention block dimensions and naming.
TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
intialized.
"""
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=self.dtype)

ip_layers = torch.nn.ModuleList(attn_procs.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._attn_processors = attn_procs
self._state_dict = None

# @genomancer: pushed scaling back out into its own method (like original Tencent implementation)
# which makes implementing begin_step_percent and end_step_percent easier
# but based on self._attn_processors (ala @Ryan) instead of original Tencent unet.attn_processors,
# which should make it easier to implement multiple IPAdapters
def set_scale(self, scale):
if self._attn_processors is not None:
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor2_0):
attn_processor.scale = scale

@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
"""A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.

Yields:
None
"""
if self._attn_processors is None:
# We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
# used on any UNet model (with the same dimensions).
self._prepare_attention_processors(unet)

# Set scale
self.set_scale(scale)
# for attn_processor in self._attn_processors.values():
# if isinstance(attn_processor, IPAttnProcessor2_0):
# attn_processor.scale = scale

orig_attn_processors = unet.attn_processors

# Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
# actually pops elements from the passed dict.
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}

try:
unet.set_attn_processor(ip_adapter_attn_processors)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)

@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
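
With the state dict no longer cached on the `IPAdapter` instance, `calc_size` is computed from the loaded modules themselves. A rough illustrative stand-in for what `calc_model_size_by_data`-style accounting amounts to for any `nn.Module` (this is not the actual InvokeAI helper):

```python
import torch


def approx_module_size_bytes(module: torch.nn.Module) -> int:
    """Approximate a module's in-memory size by summing parameter and buffer bytes."""
    return sum(
        t.nelement() * t.element_size()
        for t in list(module.parameters()) + list(module.buffers())
    )


# Example: a bias-free 768x768 linear layer in float16 occupies 768 * 768 * 2 bytes.
layer = torch.nn.Linear(768, 768, bias=False).to(dtype=torch.float16)
print(approx_module_size_bytes(layer))  # 1179648
```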