@@ -14,7 +14,6 @@
 
 from compel.cross_attention_control import Arguments
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from diffusers.models.cross_attention import AttnProcessor
 from ldm.invoke.devices import torch_dtype
 
 
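This hunk tracks the diffusers rename: `AttnProcessor` no longer lives in `diffusers.models.cross_attention`, and the classes moved to `diffusers.models.attention_processor` under new names. A minimal compatibility sketch, not part of this commit, for code that must import against both the old and new diffusers layouts:

# Illustrative shim only: fall back to the pre-rename module path on older diffusers.
try:
    from diffusers.models.attention_processor import Attention, AttnProcessor
except ImportError:
    from diffusers.models.cross_attention import (
        CrossAttention as Attention,
        CrossAttnProcessor as AttnProcessor,
    )
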
@@ -163,7 +162,7 @@ def offload_saved_attention_slices_to_cpu(self):
 
 class InvokeAICrossAttentionMixin:
     """
-    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
+    Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
     and dymamic slicing strategy selection.
     """
@@ -178,7 +177,7 @@ def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, t
         Set custom attention calculator to be called when attention is calculated
         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
             which returns either the suggested_attention_slice or an adjusted equivalent.
-            `module` is the current CrossAttention module for which the callback is being invoked.
+            `module` is the current Attention module for which the callback is being invoked.
             `suggested_attention_slice` is the default-calculated attention slice
             `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
                 If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
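To make the documented callback signature concrete, here is a hypothetical wrangler; the function name, the token index, and the 1.2 boost factor are invented for illustration. It takes (module, suggested_attention_slice, dim, offset, slice_size) and returns an adjusted slice:

import torch
import torch.nn as nn

def boost_token_wrangler(module: nn.Module, suggested_attention_slice: torch.Tensor,
                         dim: int, offset: int, slice_size: int) -> torch.Tensor:
    # The key-token axis is the last dimension, so this works whether the map
    # arrives whole (dim == -1) or as a dimension-0/1 slice (dim == 0 or 1).
    adjusted = suggested_attention_slice.clone()
    adjusted[..., 3] *= 1.2  # boost attention paid to token index 3 by 20%
    return adjusted

# Installed via the method documented above:
# some_attention_module.set_attention_slice_wrangler(boost_token_wrangler)
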
@@ -326,7 +325,7 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
 
 
 def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
-    from ldm.modules.attention import CrossAttention # avoid circular import
+    from ldm.modules.attention import CrossAttention # avoid circular import # TODO: rename as in diffusers?
     cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model, UNet2DConditionModel) else CrossAttention
     which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
     attention_module_tuples = [(name, module) for name, module in model.named_modules() if
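The lookup leans on the attn1/attn2 naming convention (attn1 = self-attention, attn2 = cross-attention). A standalone sketch of the same name-based filtering idea, assuming `unet` is a loaded UNet2DConditionModel:

def find_attention_modules(unet, which_attn: str):
    # which_attn is "attn1" (self-attention) or "attn2" (cross-attention).
    return [(name, module) for name, module in unet.named_modules()
            if name.endswith(which_attn)]

# cross_attn_modules = find_attention_modules(unet, "attn2")
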
@@ -432,7 +431,7 @@ def get_mem_free_total(device):
 
 
 
-class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
+class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -457,8 +456,8 @@ def _attention(self, query, key, value, attention_mask=None):
     """
     # base implementation
 
-    class CrossAttnProcessor:
-        def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+    class AttnProcessor:
+        def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
             batch_size, sequence_length, _ = hidden_states.shape
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
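The block above quotes (inside a comment block) the start of the diffusers base processor for reference. A self-contained sketch of a complete processor following the same protocol, assuming the post-rename diffusers API; the two-argument prepare_attention_mask call mirrors the quoted snippet, and signatures may differ across diffusers versions:

import torch
from diffusers.models.attention_processor import Attention

class PassthroughAttnProcessor:
    """Hypothetical minimal processor: mirrors the stock query/key/value flow."""

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None,
                 attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # self-attention case
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Fold the head dimension into the batch, score, attend, then unfold.
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        hidden_states = attn.to_out[0](hidden_states)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states
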
@@ -487,7 +486,7 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
 import torch
 
-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
+from diffusers.models.attention_processor import Attention, AttnProcessor, SlicedAttnProcessor
 
 
 @dataclass
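With the renamed imports in place, processors are installed through the stock diffusers hook on UNet2DConditionModel. A usage sketch, where `unet` is assumed to be a loaded UNet2DConditionModel:

from diffusers.models.attention_processor import AttnProcessor, SlicedAttnProcessor

# Install the default processor on every attention module:
unet.set_attn_processor(AttnProcessor())

# Or the memory-capped sliced variant (slice_size trades speed for memory):
unet.set_attn_processor(SlicedAttnProcessor(slice_size=4))
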
@@ -532,7 +531,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
     # TODO: dynamically pick slice size based on memory conditions
 
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                  # kwargs
                  swap_cross_attn_context: SwapCrossAttnContext = None):
 
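The extra `swap_cross_attn_context` keyword reaches this processor through the UNet's `cross_attention_kwargs` pass-through. A hypothetical call site; `unet`, `sample`, `timestep`, `text_embeddings`, and `swap_context` are assumed to exist:

# diffusers forwards cross_attention_kwargs to the active attention processors,
# which is how swap_cross_attn_context arrives in __call__ above.
noise_pred = unet(
    sample,
    timestep,
    encoder_hidden_states=text_embeddings,
    cross_attention_kwargs={"swap_cross_attn_context": swap_context},
).sample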