Support multiple IP adapter in Flux #10775

Status: Closed · Fixed by #10867
Labels: roadmap (Add to current release roadmap)

Honey-666 (Contributor) opened this issue on Feb 12, 2025 · 6 comments
When I pass the weights in the form of [0.4, 0.4], it tells me "Expected list of 19 scales, got 2."
pipe.set_ip_adapter_scale([0.4, 0.4])

Honey-666 (Contributor, Author) commented:

> When I pass the weights in the form of [0.4, 0.4], it tells me "Expected list of 19 scales, got 2."
> pipe.set_ip_adapter_scale([0.4, 0.4])

@hlky

hlky (Contributor) commented Feb 12, 2025

Multiple IP adapters for Flux are not yet implemented.

Passing a List[float] to set_ip_adapter_scale sets a per-block scale.
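For context, a single float applies the same scale everywhere, while a list of floats is interpreted as one scale per transformer block (19 for Flux, per the error above), not one per adapter. A short sketch, assuming a Flux pipeline with a single IP-Adapter already loaded (`pipe` is hypothetical):

```python
# Single float: one global IP-Adapter scale.
pipe.set_ip_adapter_scale(0.4)

# List of floats: per-block scales (19 entries for Flux, per the error message above).
pipe.set_ip_adapter_scale([0.4] * 19)

# [0.4, 0.4] is therefore parsed as two per-block scales, not as scales for two
# adapters, which is why it fails with "Expected list of 19 scales, got 2."
```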

Honey-666 (Contributor, Author) commented:

> Multiple IP adapters for Flux are not yet implemented.
>
> Passing a List[float] to set_ip_adapter_scale sets a per-block scale.

Thank you!

hlky changed the title from "Does Flux support multiple IP adapters?" to "Support multiple IP adapter in Flux" on Feb 13, 2025
hlky self-assigned this on Feb 13, 2025
hlky added the roadmap (Add to current release roadmap) label on Feb 13, 2025
guiyrt (Contributor) commented Feb 19, 2025

Hey @hlky, is this up for grabs? I can give it a go after wrapping up the first unit of the agents course 🤗

hlky (Contributor) commented Feb 20, 2025

@guiyrt You are welcome to take it up and I am happy to assist.

Flux IPAdapter is different from other IPAdapters: the IPAdapter attention output is added to hidden_states at a later point. FluxIPAdapterJointAttnProcessor2_0 is partly ready for multiple images; it still needs to add up all ip_attn_output values (the current loop keeps only the last one).

Single-image usage examples can be found here, and quantized examples here for reduced hardware requirements. There are two versions, flux-ip-adapter and flux-ip-adapter-v2. flux-ip-adapter needs "true CFG" while flux-ip-adapter-v2 doesn't, so it is faster to test; choose either one, testing both isn't necessary.

flux-ip-adapter-v2 also natively supports multiple input images by passing a batch of images. To support that together with multiple IPAdapters, ip_adapter_image would need to accept List[PipelineImageInput]. The version can be checked via transformer.encoder_hid_proj.image_projection_layers[0].num_image_text_embeds. However, this would change the interface, so it can be left for a later PR; I'll raise it with the team and loop in integrators.

Let me know if you have any questions!
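For reference, a minimal single-image sketch in the spirit of those examples. The repo ids, file name, and call parameters below are assumptions for illustration (the XLabs checkpoints and typical Flux settings), not copied from the linked examples:

```python
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Assumed repo/file for flux-ip-adapter (v1); flux-ip-adapter-v2 would be the
# XLabs-AI/flux-ip-adapter-v2 checkpoint and can skip the true-CFG settings.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)

ip_image = load_image("https://example.com/reference.png")  # placeholder reference image
image = pipe(
    prompt="a cat wearing sunglasses",
    negative_prompt="blurry, low quality",
    true_cfg_scale=4.0,  # v1 needs "true CFG"; drop this and negative_prompt for v2
    ip_adapter_image=ip_image,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_ip_adapter.png")
```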

class IPAdapterAttnProcessor2_0(torch.nn.Module):

```python
if ip_adapter_masks is not None:
    if not isinstance(ip_adapter_masks, List):
        # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
        ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
    if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
        raise ValueError(
            f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
            f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
            f"({len(ip_hidden_states)})"
        )
    else:
        for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
            if mask is None:
                continue
            if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                raise ValueError(
                    "Each element of the ip_adapter_masks array should be a tensor with shape "
                    "[1, num_images_for_ip_adapter, height, width]."
                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                )
            if mask.shape[1] != ip_state.shape[1]:
                raise ValueError(
                    f"Number of masks ({mask.shape[1]}) does not match "
                    f"number of ip images ({ip_state.shape[1]}) at index {index}"
                )
            if isinstance(scale, list) and not len(scale) == mask.shape[1]:
                raise ValueError(
                    f"Number of masks ({mask.shape[1]}) does not match "
                    f"number of scales ({len(scale)}) at index {index}"
                )
else:
    ip_adapter_masks = [None] * len(self.scale)

# for ip-adapter
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
    ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
    skip = False
    if isinstance(scale, list):
        if all(s == 0 for s in scale):
            skip = True
    elif scale == 0:
        skip = True
    if not skip:
        if mask is not None:
            if not isinstance(scale, list):
                scale = [scale] * mask.shape[1]

            current_num_images = mask.shape[1]
            for i in range(current_num_images):
                ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.scale when we move to Torch 2.1
                _current_ip_hidden_states = F.scaled_dot_product_attention(
                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )

                _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
                    batch_size, -1, attn.heads * head_dim
                )
                _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)

                mask_downsample = IPAdapterMaskProcessor.downsample(
                    mask[:, i, :, :],
                    batch_size,
                    _current_ip_hidden_states.shape[1],
                    _current_ip_hidden_states.shape[2],
                )

                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
                hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
        else:
            ip_key = to_k_ip(current_ip_hidden_states)
            ip_value = to_v_ip(current_ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            current_ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + scale * current_ip_hidden_states
```

class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):

```python
ip_query = hidden_states_query_proj
ip_attn_output = None
# for ip-adapter
# TODO: support for multiple adapters
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
    ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
    ip_key = to_k_ip(current_ip_hidden_states)
    ip_value = to_v_ip(current_ip_hidden_states)

    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    # TODO: add support for attn.scale when we move to Torch 2.1
    ip_attn_output = F.scaled_dot_product_attention(
        ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    ip_attn_output = scale * ip_attn_output
    ip_attn_output = ip_attn_output.to(ip_query.dtype)

return hidden_states, encoder_hidden_states, ip_attn_output
```
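As noted above, this loop overwrites ip_attn_output on every iteration, so only the last adapter would contribute once multiple adapters are loaded. A minimal, self-contained sketch of the accumulation idea (toy shapes and made-up variable names, not the actual processor code or the eventual fix in #10867):

```python
import torch
import torch.nn.functional as F

# Toy setup: one (key, value, scale) triple per loaded IP-Adapter.
batch_size, heads, head_dim, img_tokens, ip_tokens = 1, 4, 8, 16, 4
inner_dim = heads * head_dim

ip_query = torch.randn(batch_size, heads, img_tokens, head_dim)
adapters = [
    (
        torch.randn(batch_size, heads, ip_tokens, head_dim),
        torch.randn(batch_size, heads, ip_tokens, head_dim),
        scale,
    )
    for scale in (0.4, 0.4)
]

ip_attn_output = None
for ip_key, ip_value, scale in adapters:
    out = F.scaled_dot_product_attention(ip_query, ip_key, ip_value)
    out = out.transpose(1, 2).reshape(batch_size, -1, inner_dim)
    contribution = scale * out
    # accumulate the scaled output of each adapter instead of overwriting it
    ip_attn_output = contribution if ip_attn_output is None else ip_attn_output + contribution

print(ip_attn_output.shape)  # torch.Size([1, 16, 32])
```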

```python
# In FluxTransformerBlock.forward: the third attention output (ip_attn_output)
# is unpacked here and added to hidden_states further down.
attention_outputs = self.attn(
    hidden_states=norm_hidden_states,
    encoder_hidden_states=norm_encoder_hidden_states,
    image_rotary_emb=image_rotary_emb,
    **joint_attention_kwargs,
)
if len(attention_outputs) == 2:
    attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
    attn_output, context_attn_output, ip_attn_output = attention_outputs

# ...

if len(attention_outputs) == 3:
    hidden_states = hidden_states + ip_attn_output
```

```python
num_image_text_embeds = 4
if state_dict["proj.weight"].shape[0] == 65536:
    num_image_text_embeds = 16
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
```
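Combining this with the attribute path mentioned above, the loaded variant can be checked at runtime. A small sketch; the helper name is made up, and the 4-vs-16 mapping to v1/v2 is inferred from the state_dict check above:

```python
# Hypothetical helper: read num_image_text_embeds from the first image
# projection layer (see the path hlky mentions above). Assumes 4 image tokens
# => flux-ip-adapter (v1) and 16 => flux-ip-adapter-v2.
def flux_ip_adapter_variant(transformer) -> str:
    n = transformer.encoder_hid_proj.image_projection_layers[0].num_image_text_embeds
    return "flux-ip-adapter-v2" if n == 16 else "flux-ip-adapter"
```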

guiyrt (Contributor) commented Feb 22, 2025

Thanks for centralizing all the critical info! The initial PR is open; I have a few open points, but I'll put them there.
