granular transformers block scales for IP Adapter #38

Merged · 1 commit · Apr 2, 2024

13 changes: 11 additions & 2 deletions src/iartisanxl/diffusers_patch/ip_adapter_attention_processor.py
@@ -109,7 +109,7 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
            the weight scale of image prompt.
    """

-    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, block_transformer_name=None):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
@@ -120,6 +120,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens
+        self.block_transformer_name = block_transformer_name

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
@@ -226,7 +227,15 @@ def __call__(

                current_ip_hidden_states = current_ip_hidden_states * mask_downsample

-            hidden_states = hidden_states + scale * current_ip_hidden_states
+            scale_value = scale
+
+            if isinstance(scale, dict):
+                if self.block_transformer_name in scale:
+                    scale_value = scale[self.block_transformer_name]
+                else:
+                    scale_value = 1.0
+
+            hidden_states = hidden_states + scale_value * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
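
With this change, the `scale` a processor receives can be either a single float (the previous behavior) or a dict keyed by transformer-block name, so every attention block can apply its own IP-Adapter strength; blocks without an entry fall back to 1.0. A minimal standalone sketch of the lookup, using hypothetical diffusers-style block names:

    # Sketch of the per-block lookup above; the dict keys are hypothetical
    # examples of diffusers-style attention block names.
    scale = {"down_blocks.1.attentions.0": 0.4, "up_blocks.0.attentions.1": 0.9}
    block_transformer_name = "down_blocks.1.attentions.0"

    scale_value = scale
    if isinstance(scale, dict):
        # blocks without an explicit entry keep full strength
        scale_value = scale.get(block_transformer_name, 1.0)

    print(scale_value)  # 0.4
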
33 changes: 27 additions & 6 deletions src/iartisanxl/graph/nodes/ip_adapter_merge_node.py
@@ -7,23 +7,40 @@ class IPAdapterMergeNode(Node):
    OUTPUTS = ["ip_adapter"]

    def __call__(self) -> dict:
-        self.unet.set_attn_processor(AttnProcessor2_0())
-
-        if self.ip_adapter is not None:
+        if self.ip_adapter is None:
+            self.unet.set_attn_processor(AttnProcessor2_0())
+        else:
            ip_adapters = self.ip_adapter

            if isinstance(ip_adapters, dict):
                ip_adapters = [ip_adapters]

            weights = []
            scales = []
+            reload_weights = False

            for ip_adapter in ip_adapters:
+                if ip_adapter.get("reload_weights", False):
+                    reload_weights = True
+                    ip_adapter["reload_weights"] = False
+
                weights.append(ip_adapter["weights"])
-                scales.append(ip_adapter["scale"])

-            attn_procs = self.convert_ip_adapter_attn_to_diffusers(weights)
-            self.unet.set_attn_processor(attn_procs)
+                scale = 0.0
+
+                if ip_adapter.get("enabled", False):
+                    scale = (
+                        ip_adapter["granular_scale"]
+                        if ip_adapter.get("granular_scale_enabled", False)
+                        else ip_adapter.get("scale", 0.0)
+                    )
+
+                scales.append(scale)
+
+            if reload_weights:
+                self.unet.set_attn_processor(AttnProcessor2_0())
+                attn_procs = self.convert_ip_adapter_attn_to_diffusers(weights)
+                self.unet.set_attn_processor(attn_procs)

            for attn_processor in self.unet.attn_processors.values():
                if isinstance(attn_processor, IPAdapterAttnProcessor2_0):
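
Each adapter thus contributes one entry to `scales`: 0.0 when disabled, the granular dict when granular scaling is enabled, and the plain float otherwise. A condensed sketch of that resolution (the helper name is illustrative, not part of the codebase):

    def resolve_scale(ip_adapter: dict):
        # disabled adapters contribute nothing
        if not ip_adapter.get("enabled", False):
            return 0.0
        # granular: a dict mapping block name -> per-block strength
        if ip_adapter.get("granular_scale_enabled", False):
            return ip_adapter["granular_scale"]
        # otherwise a single global strength
        return ip_adapter.get("scale", 0.0)
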
Expand Down Expand Up @@ -69,11 +86,15 @@ def convert_ip_adapter_attn_to_diffusers(self, state_dicts):
# IP-Adapter Plus
num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]

name_parts = name.split(".")
block_transformer_name = ".".join(name_parts[:4])

attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
block_transformer_name=block_transformer_name,
).to(dtype=self.torch_dtype, device=self.device)

value_dict = {}
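
For context, attention-processor names in a diffusers SDXL UNet look like `down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor`, so keeping the first four dot-separated parts yields the per-block key that the granular scale dict is matched against (note that `mid_block` names have no numeric index, so the same slice picks up `transformer_blocks` as the fourth part). A quick illustration with a hypothetical name:

    # Hypothetical diffusers-style attention processor name
    name = "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"
    name_parts = name.split(".")
    block_transformer_name = ".".join(name_parts[:4])
    print(block_transformer_name)  # down_blocks.1.attentions.0
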
39 changes: 36 additions & 3 deletions src/iartisanxl/graph/nodes/ip_adapter_node.py
@@ -13,28 +13,53 @@ class IPAdapterNode(Node):
    OPTIONAL_INPUTS = ["mask_alpha_image"]
    OUTPUTS = ["ip_adapter"]

-    def __init__(self, type_index: int, adapter_type: str, adapter_scale: float = None, **kwargs):
+    def __init__(
+        self,
+        type_index: int,
+        adapter_type: str,
+        adapter_scale: float = None,
+        granular_scale_enabled: bool = False,
+        granular_scale: dict = None,
+        **kwargs,
+    ):
        super().__init__(**kwargs)

        self.type_index = type_index
        self.adapter_type = adapter_type
        self.adapter_scale = adapter_scale
+        self.adapter_granuler_scale = granular_scale
+        self.granular_scale_enabled = granular_scale_enabled
        self.ip_image_prompt_embeds = None
+        self.reload_weights = True

        self.clip_image_processor = CLIPImageProcessor()

-    def update_adapter(self, type_index: int, adapter_type: str, enabled: bool, adapter_scale: float = None):
+    def update_adapter(
+        self,
+        type_index: int,
+        adapter_type: str,
+        enabled: bool,
+        adapter_scale: float = None,
+        granular_scale_enabled: bool = False,
+        granular_scale: dict = None,
+        reload_weights: bool = False,
+    ):
        self.type_index = type_index
        self.adapter_type = adapter_type
        self.enabled = enabled
        self.adapter_scale = adapter_scale
+        self.granular_scale_enabled = granular_scale_enabled
+        self.adapter_granuler_scale = granular_scale
+        self.reload_weights = reload_weights
        self.set_updated()

    def to_dict(self):
        node_dict = super().to_dict()
        node_dict["type_index"] = self.type_index
        node_dict["adapter_type"] = self.adapter_type
        node_dict["adapter_scale"] = self.adapter_scale
+        node_dict["granular_scale_enabled"] = self.granular_scale_enabled
+        node_dict["adapter_granuler_scale"] = self.adapter_granuler_scale
        return node_dict

    @classmethod
@@ -43,12 +68,16 @@ def from_dict(cls, node_dict, _callbacks=None):
        node.type_index = node_dict["type_index"]
        node.adapter_type = node_dict["adapter_type"]
        node.adapter_scale = node_dict["adapter_scale"]
+        node.granular_scale_enabled = node_dict["granular_scale_enabled"]
+        node.adapter_granuler_scale = node_dict["adapter_granuler_scale"]
        return node

    def update_inputs(self, node_dict):
        self.type_index = node_dict["type_index"]
        self.adapter_type = node_dict["adapter_type"]
        self.adapter_scale = node_dict["adapter_scale"]
+        self.granular_scale_enabled = node_dict["granular_scale_enabled"]
+        self.adapter_granuler_scale = node_dict["adapter_granuler_scale"]

    def __call__(self) -> dict:
        image_projection = self.convert_ip_adapter_image_proj_to_diffusers(self.ip_adapter_model["image_proj"])
@@ -65,7 +94,7 @@ def __call__(self) -> dict:
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(image, output_hidden_states)

        # save_embeds = torch.cat([uncond_image_prompt_embeds, image_prompt_embeds])
-        # torch.save(save_embeds, "C:/Users/Ozzy/Desktop/iartisanxl_style_test.ipadpt")
+        # torch.save(save_embeds, "iartisanxl_style_test.ipadpt")

        tensor_mask = None
        if self.mask_alpha_image is not None:
@@ -81,9 +110,13 @@ def __call__(self) -> dict:
"weights": self.ip_adapter_model,
"image_prompt_embeds": image_prompt_embeds,
"uncond_image_prompt_embeds": uncond_image_prompt_embeds,
"enabled": self.enabled,
"scale": self.adapter_scale,
"granular_scale_enabled": self.granular_scale_enabled,
"granular_scale": self.adapter_granuler_scale,
"tensor_mask": tensor_mask,
"image_projection": image_projection,
"reload_weights": self.reload_weights,
}

return self.values
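
A hypothetical call showing how a graph could hand the new granular settings to this node (the argument values are illustrative only):

    node.update_adapter(
        type_index=0,
        adapter_type="ip_adapter_plus",  # illustrative adapter type
        enabled=True,
        adapter_scale=0.7,
        granular_scale_enabled=True,
        granular_scale={"down_blocks.1.attentions.0": 0.4},
        reload_weights=False,
    )
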
114 changes: 114 additions & 0 deletions src/iartisanxl/modules/common/ip_adapter/advanced_widget.py
@@ -0,0 +1,114 @@
from PyQt6.QtCore import Qt, pyqtSignal
from PyQt6.QtWidgets import QCheckBox, QFrame, QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget
from superqt import QLabeledDoubleSlider

from iartisanxl.modules.common.ip_adapter.ip_adapter_data_object import IPAdapterDataObject


class AdvancedWidget(QWidget):
    advanced_canceled = pyqtSignal()
    granular_enabled = pyqtSignal(bool)

    def __init__(self, ip_adapter: IPAdapterDataObject):
        super().__init__()

        self.ip_adapter = ip_adapter
        self.attention_values = {
            "down_1": [1.0, 1.0],
            "down_2": [1.0, 1.0],
            "mid": [1.0],
            "up_0": [1.0, 1.0, 1.0],
            "up_1": [1.0, 1.0, 1.0],
        }

        self.sliders = {}
        self.frames = []

        self.init_ui()

    def init_ui(self):
        main_layout = QVBoxLayout()

        granular_scales_checkbox = QCheckBox("Enable granular scales")
        granular_scales_checkbox.stateChanged.connect(self.on_granular)
        main_layout.addWidget(granular_scales_checkbox)

        sections_layout = QHBoxLayout()

        for section, values in self.attention_values.items():
            frame = QFrame()
            frame.setDisabled(True)
            frame.setObjectName("block_frame")

            blocks_layout = QVBoxLayout()
            section_label = QLabel(f"{section.capitalize()} Blocks")
            blocks_layout.addWidget(section_label)

            # Create one slider per attention block in this section
            for i, value in enumerate(values):
                attention_layout = QHBoxLayout()
                # label each slider by its 1-based position within the section
                attention_label = QLabel(f"Attention {i + 1}")
                attention_layout.addWidget(attention_label)
                attention_slider = QLabeledDoubleSlider(Qt.Orientation.Horizontal)
                attention_slider.setRange(0.0, 1.0)
                attention_slider.setValue(value)
                attention_slider.valueChanged.connect(lambda val, sec=section, idx=i: self.update_scale(val, sec, idx))
                attention_layout.addWidget(attention_slider)
                blocks_layout.addLayout(attention_layout)

                self.sliders.setdefault(section, []).append(attention_slider)

            frame.setLayout(blocks_layout)
            sections_layout.addWidget(frame)
            self.frames.append(frame)

        main_layout.addLayout(sections_layout)
        main_layout.addStretch()

        button_layout = QHBoxLayout()
        save_button = QPushButton("Set scales")
        save_button.clicked.connect(self.on_save)
        button_layout.addWidget(save_button)
        cancel_button = QPushButton("Cancel")
        cancel_button.clicked.connect(self.on_cancel)
        button_layout.addWidget(cancel_button)

        main_layout.addLayout(button_layout)

        self.setLayout(main_layout)

    def update_scale(self, value, section, index):
        self.attention_values[section][index] = value

    def on_cancel(self):
        self.advanced_canceled.emit()

    def on_save(self):
        scales = {}
        for section, values in self.attention_values.items():
            if section.startswith("down_"):
                block = "down_blocks"
                block_num = section.split("_")[1]
            elif section == "mid":
                block = "mid_block"
                block_num = ""
            elif section.startswith("up_"):
                block = "up_blocks"
                block_num = section.split("_")[1]

            for i, value in enumerate(values):
                # mid_block has no numeric index; skip the empty part so the
                # key does not come out as "mid_block..attentions.0"
                prefix = f"{block}.{block_num}" if block_num else block
                scales[f"{prefix}.attentions.{i}"] = value

        self.ip_adapter.granular_scale = scales
        self.advanced_canceled.emit()

    def on_granular(self, state):
        is_enabled = state != Qt.CheckState.Unchecked.value

        for frame in self.frames:
            frame.setEnabled(is_enabled)

        self.granular_enabled.emit(is_enabled)
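
For reference, a hypothetical example of the mapping `on_save` produces from the sections above, assuming all sliders are left at their default value of 1.0:

    {
        "down_blocks.1.attentions.0": 1.0,
        "down_blocks.1.attentions.1": 1.0,
        "down_blocks.2.attentions.0": 1.0,
        "down_blocks.2.attentions.1": 1.0,
        "mid_block.attentions.0": 1.0,
        "up_blocks.0.attentions.0": 1.0,
        "up_blocks.0.attentions.1": 1.0,
        "up_blocks.0.attentions.2": 1.0,
        "up_blocks.1.attentions.0": 1.0,
        "up_blocks.1.attentions.1": 1.0,
        "up_blocks.1.attentions.2": 1.0,
    }
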
src/iartisanxl/modules/common/ip_adapter/ip_adapter_data_object.py
@@ -13,6 +13,8 @@ class IPAdapterDataObject:
    adapter_type: str = attr.ib(default=None)
    type_index: int = attr.ib(default=0)
    ip_adapter_scale: float = attr.ib(default=1.0)
+    granular_scale_enabled: bool = attr.ib(default=False)
+    granular_scale: dict = attr.ib(default=None)
    enabled: bool = attr.ib(default=True)
    node_id: int = attr.ib(default=None)
    adapter_id: int = attr.ib(default=None)
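
Finally, a hypothetical construction of the data object with the two new fields, assuming the fields not shown in the hunk above keep their attrs defaults (the adapter type and scale values are illustrative):

    data = IPAdapterDataObject(
        adapter_type="ip_adapter_plus",  # illustrative value
        ip_adapter_scale=0.7,
        granular_scale_enabled=True,
        granular_scale={"down_blocks.1.attentions.0": 0.4},
    )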