@@ -10,7 +10,7 @@
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.cross_attention import AttnProcessor
+from diffusers.models.attention_processor import AttentionProcessor
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from torch import nn
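This import move tracks diffusers' rename of `diffusers.models.cross_attention` to `diffusers.models.attention_processor`: `CrossAttention` became `Attention`, the default `CrossAttnProcessor` became `AttnProcessor`, and the `AttnProcessor` type alias became `AttentionProcessor`. For code that has to run against both module layouts, a minimal compatibility sketch (not part of this diff) might look like this; the fallback branch assumes an older diffusers that still ships the `cross_attention` module:

```python
# Compatibility sketch: prefer the new module layout, fall back to the
# pre-rename names on older diffusers releases.
try:
    from diffusers.models.attention_processor import Attention, AttnProcessor
except ImportError:
    # older diffusers: same classes under their pre-rename names
    from diffusers.models.cross_attention import (
        CrossAttention as Attention,
        CrossAttnProcessor as AttnProcessor,
    )
```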
@@ -188,7 +188,7 @@ def offload_saved_attention_slices_to_cpu(self):
 
 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dynamic slicing strategy selection.
     """
@@ -209,7 +209,7 @@ def set_attention_slice_wrangler(
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
                          which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
+            `module` is the current Attention module for which the callback is being invoked.
             `suggested_attention_slice` is the default-calculated attention slice
             `dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
                 If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
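Per the docstring above, a wrangler receives `(module, suggested_attention_slice, dim, offset, slice_size)` and must return a tensor of the same shape as the suggested slice. A minimal sketch of a conforming callback; the logging behaviour is illustrative, not part of the API:

```python
def logging_wrangler(module, suggested_attention_slice, dim, offset, slice_size):
    # dim == -1 means the whole attention map; otherwise this call covers
    # only the slice [offset : offset + slice_size] along dimension `dim`.
    if dim == -1:
        print(f"full map, shape {tuple(suggested_attention_slice.shape)}")
    else:
        print(f"slice: dim={dim}, offset={offset}, size={slice_size}")
    # return the suggested slice unchanged, or an adjusted tensor of the same shape
    return suggested_attention_slice

# installed on a module that uses InvokeAICrossAttentionMixin:
# module.set_attention_slice_wrangler(logging_wrangler)
```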
@@ -345,11 +345,11 @@ def get_invokeai_attention_mem_efficient(self, q, k, v):
 def restore_default_cross_attention(
     model,
     is_running_diffusers: bool,
-    restore_attention_processor: Optional[AttnProcessor] = None,
+    restore_attention_processor: Optional[AttentionProcessor] = None,
 ):
     if is_running_diffusers:
         unet = model
-        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
+        unet.set_attn_processor(restore_attention_processor or AttnProcessor())
     else:
         remove_attention_function(model)
 
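On the diffusers branch, restoring defaults is just swapping the stock processor back in for every attention module. A hypothetical usage sketch, where `pipe` is an assumed `StableDiffusionPipeline`:

```python
# Usage sketch; `pipe` and its unet are assumptions for illustration.
restore_default_cross_attention(pipe.unet, is_running_diffusers=True)

# which, per the branch above, is equivalent to calling diffusers directly:
# pipe.unet.set_attn_processor(AttnProcessor())
```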
@@ -408,7 +408,7 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
 def get_cross_attention_modules(
     model, which: CrossAttentionType
 ) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention  # avoid circular import
+    from ldm.modules.attention import CrossAttention  # avoid circular import - TODO: rename as in diffusers?
 
     cross_attention_class: type = (
         InvokeAIDiffusersCrossAttention
@@ -428,10 +428,10 @@ def get_cross_attention_modules(
         print(
             f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
             + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
-            + f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+            + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
             + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-            + f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
-            + f"work properly until it is fixed."
+            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+            + "work properly until it is fixed."
         )
     return attention_module_tuples
 
@@ -550,7 +550,7 @@ def get_mem_free_total(device):
 
 
 class InvokeAIDiffusersCrossAttention(
-    diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
+    diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
 ):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -572,8 +572,8 @@ def _attention(self, query, key, value, attention_mask=None):
 """
 # base implementation
 
-class CrossAttnProcessor:
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+class AttnProcessor:
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
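The block above quotes diffusers' stock processor for reference: anything with that `__call__` signature can be installed via `set_attn_processor`. A minimal pass-through sketch of the protocol (the class name is illustrative, not from this codebase):

```python
from diffusers.models.attention_processor import Attention, AttnProcessor

class PassThroughProcessor:
    """Illustrative sketch: conforms to the processor protocol and
    delegates the actual attention math to the stock AttnProcessor."""

    def __init__(self):
        self._inner = AttnProcessor()

    def __call__(self, attn: Attention, hidden_states,
                 encoder_hidden_states=None, attention_mask=None):
        # encoder_hidden_states is set for cross-attention and None for
        # self-attention; a custom processor could branch on that here.
        return self._inner(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )

# installed the same way as any diffusers processor:
# unet.set_attn_processor(PassThroughProcessor())
```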
@@ -601,9 +601,9 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 from dataclasses import dataclass, field
 
 import torch
-from diffusers.models.cross_attention import (
-    CrossAttention,
-    CrossAttnProcessor,
+from diffusers.models.attention_processor import (
+    Attention,
+    AttnProcessor,
     SlicedAttnProcessor,
 )
 
@@ -653,7 +653,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
     def __call__(
         self,
-        attn: CrossAttention,
+        attn: Attention,
         hidden_states,
         encoder_hidden_states=None,
         attention_mask=None,