
Commit 47b9910

update to diffusers 0.15 and fix code for name changes
- This is a port of #3184 to the main branch
1 parent 23d65e7 commit 47b9910

4 files changed: +32 -26 lines changed

invokeai/backend/stable_diffusion/diffusers_pipeline.py

Lines changed: 11 additions & 4 deletions

@@ -445,8 +445,15 @@ def device(self) -> torch.device:
     @property
     def _submodels(self) -> Sequence[torch.nn.Module]:
         module_names, _, _ = self.extract_init_dict(dict(self.config))
-        values = [getattr(self, name) for name in module_names.keys()]
-        return [m for m in values if isinstance(m, torch.nn.Module)]
+        submodels = []
+        for name in module_names.keys():
+            if hasattr(self, name):
+                value = getattr(self, name)
+            else:
+                value = getattr(self.config, name)
+            if isinstance(value, torch.nn.Module):
+                submodels.append(value)
+        return submodels

     def image_from_embeddings(
         self,
@@ -544,7 +551,7 @@ def generate_latents_from_embeddings(
             yield PipelineIntermediateState(
                 run_id=run_id,
                 step=-1,
-                timestep=self.scheduler.num_train_timesteps,
+                timestep=self.scheduler.config.num_train_timesteps,
                 latents=latents,
             )

@@ -915,7 +922,7 @@ def get_learned_conditioning(
     @property
     def channels(self) -> int:
         """Compatible with DiffusionWrapper"""
-        return self.unet.in_channels
+        return self.unet.config.in_channels

     def decode_latents(self, latents):
         # Explicit call to get the vae loaded, since `decode` isn't the forward method.
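
The pattern behind these hunks: in diffusers 0.15, reading hyperparameters such as in_channels or num_train_timesteps directly off a model or scheduler module is deprecated, and the values are read from each component's config object instead. A minimal sketch of the new access pattern, assuming diffusers>=0.15 (the checkpoint name is illustrative only):

# Minimal sketch, assuming diffusers>=0.15; the checkpoint name is illustrative.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Hyperparameters now live on each component's config rather than on the module:
latent_channels = pipe.unet.config.in_channels               # e.g. 4
train_timesteps = pipe.scheduler.config.num_train_timesteps  # e.g. 1000
print(latent_channels, train_timesteps)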

invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py

Lines changed: 16 additions & 16 deletions

@@ -10,7 +10,7 @@
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.cross_attention import AttnProcessor
+from diffusers.models.attention_processor import AttentionProcessor
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from torch import nn

@@ -188,7 +188,7 @@ def offload_saved_attention_slices_to_cpu(self):

 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dymamic slicing strategy selection.
     """
@@ -209,7 +209,7 @@ def set_attention_slice_wrangler(
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
                          which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
+            `module` is the current Attention module for which the callback is being invoked.
             `suggested_attention_slice` is the default-calculated attention slice
             `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
             If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
@@ -345,11 +345,11 @@ def get_invokeai_attention_mem_efficient(self, q, k, v):
 def restore_default_cross_attention(
     model,
     is_running_diffusers: bool,
-    restore_attention_processor: Optional[AttnProcessor] = None,
+    restore_attention_processor: Optional[AttentionProcessor] = None,
 ):
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
+        unet.set_attn_processor(restore_attention_processor or AttnProcessor())
     else:
         remove_attention_function(model)

@@ -408,7 +408,7 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
 def get_cross_attention_modules(
     model, which: CrossAttentionType
 ) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention  # avoid circular import
+    from ldm.modules.attention import CrossAttention  # avoid circular import - TODO: rename as in diffusers?

     cross_attention_class: type = (
         InvokeAIDiffusersCrossAttention
@@ -428,10 +428,10 @@ def get_cross_attention_modules(
         print(
             f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
             + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
-            + f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+            + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
             + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-            + f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
-            + f"work properly until it is fixed."
+            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+            + "work properly until it is fixed."
         )
     return attention_module_tuples

@@ -550,7 +550,7 @@ def get_mem_free_total(device):


 class InvokeAIDiffusersCrossAttention(
-    diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
+    diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
 ):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -572,8 +572,8 @@ def _attention(self, query, key, value, attention_mask=None):
 """
 # base implementation

-class CrossAttnProcessor:
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+class AttnProcessor:
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
@@ -601,9 +601,9 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 from dataclasses import dataclass, field

 import torch
-from diffusers.models.cross_attention import (
-    CrossAttention,
-    CrossAttnProcessor,
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
     SlicedAttnProcessor,
 )
@@ -653,7 +653,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):

     def __call__(
         self,
-        attn: CrossAttention,
+        attn: Attention,
         hidden_states,
         encoder_hidden_states=None,
         attention_mask=None,
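
The renames above track diffusers 0.15, where the diffusers.models.cross_attention module became diffusers.models.attention_processor, CrossAttention became Attention, and CrossAttnProcessor became AttnProcessor (with AttentionProcessor serving as the union type used in annotations). A minimal sketch of resetting a UNet to the stock processor under the new names, assuming diffusers>=0.15 (the checkpoint name is illustrative only):

# Minimal sketch, assuming diffusers>=0.15; the checkpoint name is illustrative.
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Replace any custom attention processors with the stock implementation,
# mirroring what restore_default_cross_attention() does on the diffusers path.
unet.set_attn_processor(AttnProcessor())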

invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py

Lines changed: 4 additions & 5 deletions

@@ -5,7 +5,7 @@

 import numpy as np
 import torch
-from diffusers.models.cross_attention import AttnProcessor
+from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias

 from invokeai.backend.globals import Globals
@@ -101,7 +101,7 @@ def custom_attention_context(

     def override_cross_attention(
         self, conditioning: ExtraConditioningInfo, step_count: int
-    ) -> Dict[str, AttnProcessor]:
+    ) -> Dict[str, AttentionProcessor]:
         """
         setup cross attention .swap control. for diffusers this replaces the attention processor, so
         the previous attention processor is returned so that the caller can restore it later.
@@ -118,7 +118,7 @@ def override_cross_attention(
         )

     def restore_default_cross_attention(
-        self, restore_attention_processor: Optional["AttnProcessor"] = None
+        self, restore_attention_processor: Optional["AttentionProcessor"] = None
     ):
         self.conditioning = None
         self.cross_attention_control_context = None
@@ -262,7 +262,7 @@ def calculate_percent_through(self, sigma, step_index, total_step_count):
         # TODO remove when compvis codepath support is dropped
         if step_index is None and sigma is None:
             raise ValueError(
-                f"Either step_index or sigma is required when doing cross attention control, but both are None."
+                "Either step_index or sigma is required when doing cross attention control, but both are None."
             )
         percent_through = self.estimate_percent_through(step_index, sigma)
         return percent_through
@@ -599,7 +599,6 @@ def apply_conjunction(
         )

         # below is fugly omg
-        num_actual_conditionings = len(c_or_weighted_c_list)
         conditionings = [uc] + [c for c, weight in weighted_cond_list]
         weights = [1] + [weight for c, weight in weighted_cond_list]
         chunk_count = ceil(len(conditionings) / 2)
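
The AttentionProcessor name imported here is the diffusers 0.15 union type covering the concrete processor classes, which is why it annotates the mapping that override_cross_attention returns and restore_default_cross_attention accepts. A minimal sketch of that save-and-restore pattern, assuming diffusers>=0.15 (the checkpoint name is illustrative only):

# Minimal sketch, assuming diffusers>=0.15; the checkpoint name is illustrative.
from typing import Dict

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Save the processors currently installed on the UNet (module name -> processor) ...
saved: Dict[str, AttentionProcessor] = dict(unet.attn_processors)

# ... custom processors could be swapped in here via unet.set_attn_processor(...) ...

# ... and restore the saved mapping afterwards.
unet.set_attn_processor(saved)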

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ dependencies = [
   "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "compel==1.0.5",
   "datasets",
-  "diffusers[torch]==0.14",
+  "diffusers[torch]==0.15.*",
   "dnspython==2.2.1",
   "einops",
   "eventlet",
