Multi-Image IP-Adapter #4882


Merged: 6 commits, Oct 18, 2023
6 changes: 3 additions & 3 deletions invokeai/app/invocations/ip_adapter.py
@@ -36,7 +36,7 @@ class CLIPVisionModelField(BaseModel):


class IPAdapterField(BaseModel):
- image: ImageField = Field(description="The IP-Adapter image prompt.")
+ image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
@@ -55,12 +55,12 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")


@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""

# Inputs
- image: ImageField = InputField(description="The IP-Adapter image prompt.")
+ image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
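For context on the field change above: once `image` is typed as `Union[ImageField, List[ImageField]]`, downstream consumers can no longer assume a single image. A minimal pydantic sketch (the `ImageField` stand-in and `IPAdapterFieldSketch` name are hypothetical, not InvokeAI code) of how the annotation accepts both forms:

```python
from typing import List, Union

from pydantic import BaseModel, Field


class ImageField(BaseModel):
    """Hypothetical stand-in for invokeai's ImageField."""

    image_name: str


class IPAdapterFieldSketch(BaseModel):
    # Mirrors the new annotation: a single image prompt or a list of them.
    image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")


single = IPAdapterFieldSketch(image=ImageField(image_name="statue.png"))
multi = IPAdapterFieldSketch(image=[ImageField(image_name="statue.png"), ImageField(image_name="vase.png")])
print(type(single.image).__name__)  # ImageField
print(type(multi.image).__name__)   # list
```

Code that reads the field therefore has to normalize to a list, which is exactly what the `prep_ip_adapter_data` change in the next file does.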
17 changes: 11 additions & 6 deletions invokeai/app/invocations/latent.py
@@ -214,7 +214,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.3.0",
version="1.4.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@@ -491,16 +491,21 @@ def prep_ip_adapter_data(
context=context,
)

- input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
+ # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
+ single_ipa_images = single_ip_adapter.image
+ if not isinstance(single_ipa_images, list):
+     single_ipa_images = [single_ipa_images]

+ single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]

# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
- (
-     image_prompt_embeds,
-     uncond_image_prompt_embeds,
- ) = ip_adapter_model.get_image_embeds(input_image, image_encoder_model)
+ image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
+     single_ipa_images, image_encoder_model
+ )

conditioning_data.ip_adapter_conditioning.append(
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
)
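The normalized list of PIL images is passed to `get_image_embeds` in a single call. As a rough sketch of what batching several image prompts for a CLIP vision encoder looks like (an assumption about the internals, not InvokeAI's actual implementation), a `transformers` image processor turns a list of PIL images into one tensor with a leading per-image dimension:

```python
from PIL import Image
from transformers import CLIPImageProcessor

# Two dummy image prompts standing in for the user's IP-Adapter images.
images = [Image.new("RGB", (512, 512)), Image.new("RGB", (768, 768))]

processor = CLIPImageProcessor()  # default CLIP preprocessing: resize + center-crop to 224
pixel_values = processor(images=images, return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([2, 3, 224, 224]) -- one row per image prompt
```

The resulting embeddings then carry a per-image dimension, which is why the `IPAdapterConditioningInfo` docstrings further down change from `batch_size` to `num_images`.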
21 changes: 18 additions & 3 deletions invokeai/backend/ip_adapter/attention_processor.py
@@ -67,6 +67,12 @@ def __call__(
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Apply IP-Adapter attention.

Args:
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
"""
residual = hidden_states

if attn.spatial_norm is not None:
@@ -127,26 +133,35 @@ def __call__(
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
- # The channel dimensions should match.
- assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
+ # The token_len dimensions should match.
+ assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]

ip_hidden_states = ipa_embed

+ # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)

ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)

+ # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)

ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

- # The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
+ # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)

# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)

+ # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)

ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)

+ # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)

hidden_states = hidden_states + scale * ip_hidden_states

# linear proj
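The shape comments added above can be sanity-checked in isolation. The sketch below uses illustrative sizes and replaces `to_k_ip`/`to_v_ip` with plain `nn.Linear` stand-ins; it walks multi-image embeddings through the same view/transpose/SDPA steps and prints the intermediate shapes:

```python
import torch
import torch.nn.functional as F
from torch import nn

batch_size, num_ip_images, ip_seq_len, ip_embedding_len = 2, 3, 4, 768
num_heads, head_dim, query_seq_len = 8, 40, 64  # illustrative attention config

# Stand-ins for the per-adapter to_k_ip / to_v_ip projection weights.
to_k_ip = nn.Linear(ip_embedding_len, num_heads * head_dim, bias=False)
to_v_ip = nn.Linear(ip_embedding_len, num_heads * head_dim, bias=False)

ipa_embed = torch.randn(batch_size, num_ip_images, ip_seq_len, ip_embedding_len)
query = torch.randn(batch_size, num_heads, query_seq_len, head_dim)

ip_key = to_k_ip(ipa_embed)    # (batch_size, num_ip_images, ip_seq_len, num_heads * head_dim)
ip_value = to_v_ip(ipa_embed)

# Folding num_ip_images into the sequence axis lets the query attend over
# every image's tokens at once.
ip_key = ip_key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
print(ip_key.shape)  # torch.Size([2, 8, 12, 40]) = (batch, heads, num_ip_images * ip_seq_len, head_dim)

ip_hidden_states = F.scaled_dot_product_attention(
    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
print(ip_hidden_states.shape)  # torch.Size([2, 8, 64, 40]) = (batch, heads, query_seq_len, head_dim)

ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
print(ip_hidden_states.shape)  # torch.Size([2, 64, 320]) = (batch, query_seq_len, heads * head_dim)
```

Because the extra images only extend the key/value sequence length, the attention output keeps the UNet's own (batch, query_seq_len, dim) shape no matter how many image prompts are supplied.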
@@ -55,11 +55,11 @@ class PostprocessingSettings:
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoder conditioning embeddings.
- Shape: (batch_size, num_tokens, encoding_dim).
+ Shape: (num_images, num_tokens, encoding_dim).
"""
uncond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoding embeddings to use for unconditional generation.
- Shape: (batch_size, num_tokens, encoding_dim).
+ Shape: (num_images, num_tokens, encoding_dim).
"""


@@ -345,9 +345,12 @@ def _apply_standard_conditioning(self, x, sigma, conditioning_data: Conditioning

cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
+ # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
- torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+ torch.stack(
+     [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
+ )
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
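To see why the switch from `torch.cat` to `torch.stack` matters here, a minimal shape check (illustrative sizes): with conditional and unconditional embeddings of shape `(num_images, seq_len, token_len)`, `cat` folds the images into the batch axis, while `stack` adds the leading batch axis the attention processor now expects.

```python
import torch

num_images, seq_len, token_len = 3, 4, 768  # illustrative sizes
cond = torch.randn(num_images, seq_len, token_len)
uncond = torch.randn(num_images, seq_len, token_len)

print(torch.cat([uncond, cond]).shape)    # torch.Size([6, 4, 768]) -- images merged into dim 0
print(torch.stack([uncond, cond]).shape)  # torch.Size([2, 3, 4, 768]) -- (batch, num_ip_images, seq_len, token_len)
```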
@@ -415,9 +418,10 @@ def _apply_standard_conditioning_sequentially(
# Run unconditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
+ # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
- ipa_conditioning.uncond_image_prompt_embeds
+ torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
@@ -444,9 +448,10 @@ def _apply_standard_conditioning_sequentially(
# Run conditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
+ # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
- ipa_conditioning.cond_image_prompt_embeds
+ torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
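In the sequential path (separate unconditional and conditional UNet passes) there is no uncond/cond pair to stack, so the same 4-D layout is reached by unsqueezing a batch dimension of 1. A quick check with the same illustrative sizes:

```python
import torch

num_images, seq_len, token_len = 3, 4, 768  # illustrative sizes
cond = torch.randn(num_images, seq_len, token_len)

# unsqueeze(dim=0) yields (batch_size=1, num_ip_images, seq_len, token_len),
# matching the layout produced by torch.stack in the batched path.
print(torch.unsqueeze(cond, dim=0).shape)  # torch.Size([1, 3, 4, 768])
```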
5 changes: 4 additions & 1 deletion tests/backend/ip_adapter/test_ip_adapter.py
@@ -65,7 +65,10 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
ip_adapter.to(torch_device, dtype=torch.float32)
unet.to(torch_device, dtype=torch.float32)

cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
# ip_embeds shape: (batch_size, num_ip_images, seq_len, ip_image_embedding_len)
ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)

cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample