@@ -406,6 +406,9 @@ def __init__(
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
+        self.use_ada_layer_norm = num_embeds_ada_norm is not None
+
+        # 1. Self-Attn
         self.attn1 = CrossAttention(
             query_dim=dim,
             heads=num_attention_heads,
@@ -415,23 +418,28 @@ def __init__(
             cross_attention_dim=cross_attention_dim if only_cross_attention else None,
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
-        self.attn2 = CrossAttention(
-            query_dim=dim,
-            cross_attention_dim=cross_attention_dim,
-            heads=num_attention_heads,
-            dim_head=attention_head_dim,
-            dropout=dropout,
-            bias=attention_bias,
-        )  # is self-attn if context is none
 
-        # layer norms
-        self.use_ada_layer_norm = num_embeds_ada_norm is not None
-        if self.use_ada_layer_norm:
-            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
-            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        # 2. Cross-Attn
+        if cross_attention_dim is not None:
+            self.attn2 = CrossAttention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+            )  # is self-attn if context is none
         else:
-            self.norm1 = nn.LayerNorm(dim)
-            self.norm2 = nn.LayerNorm(dim)
+            self.attn2 = None
+
+        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+        if cross_attention_dim is not None:
+            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+        else:
+            self.norm2 = None
+
+        # 3. Feed-forward
         self.norm3 = nn.LayerNorm(dim)
 
         # if xformers is installed try to use memory_efficient_attention by default
@@ -481,11 +489,12 @@ def forward(self, hidden_states, context=None, timestep=None):
         else:
             hidden_states = self.attn1(norm_hidden_states) + hidden_states
 
-        # 2. Cross-Attention
-        norm_hidden_states = (
-            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
-        )
-        hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
+        if self.attn2 is not None:
+            # 2. Cross-Attention
+            norm_hidden_states = (
+                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+            )
+            hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
 
         # 3. Feed-forward
         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
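
For context, the hunk above reduces to a simple "skip if absent" pattern. The following is a minimal, self-contained sketch of that control flow, not the diffusers class itself: `ToyBlock` and its use of `nn.MultiheadAttention` as a stand-in for `CrossAttention`, plus all dimensions, are invented for illustration.

```python
# Illustrative sketch only: mirrors the control flow of the patched block,
# with nn.MultiheadAttention standing in for diffusers' CrossAttention.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim, heads, cross_attention_dim=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn1 = nn.MultiheadAttention(dim, heads, batch_first=True)

        # Cross-attention is only built when a context dimension is given,
        # mirroring `self.attn2 = None` / `self.norm2 = None` in the diff.
        if cross_attention_dim is not None:
            self.norm2 = nn.LayerNorm(dim)
            self.attn2 = nn.MultiheadAttention(
                dim, heads, kdim=cross_attention_dim, vdim=cross_attention_dim, batch_first=True
            )
        else:
            self.norm2 = None
            self.attn2 = None

        self.norm3 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, hidden_states, context=None):
        # 1. Self-attention + residual
        h = self.norm1(hidden_states)
        hidden_states = self.attn1(h, h, h)[0] + hidden_states

        # 2. Cross-attention + residual, skipped entirely when attn2 is None
        if self.attn2 is not None:
            h = self.norm2(hidden_states)
            hidden_states = self.attn2(h, context, context)[0] + hidden_states

        # 3. Feed-forward + residual
        return self.ff(self.norm3(hidden_states)) + hidden_states


x = torch.randn(2, 16, 64)                                    # (batch, tokens, dim)
ctx = torch.randn(2, 8, 96)                                   # (batch, context tokens, context dim)
print(ToyBlock(64, 4)(x).shape)                               # self-attention only
print(ToyBlock(64, 4, cross_attention_dim=96)(x, ctx).shape)  # with cross-attention
```

When `cross_attention_dim` is omitted, `attn2` and `norm2` stay `None` and the forward pass goes straight from self-attention to the feed-forward, which is exactly what the `if self.attn2 is not None:` guard above enables.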
@@ -666,14 +675,16 @@ def __init__(
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
 
-        if activation_fn == "geglu":
-            geglu = GEGLU(dim, inner_dim)
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim)
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim)
         elif activation_fn == "geglu-approximate":
-            geglu = ApproximateGELU(dim, inner_dim)
+            act_fn = ApproximateGELU(dim, inner_dim)
 
         self.net = nn.ModuleList([])
         # project in
-        self.net.append(geglu)
+        self.net.append(act_fn)
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
@@ -685,6 +696,27 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class GELU(nn.Module):
+    r"""
+    GELU activation function
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out)
+
+    def gelu(self, gate):
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.gelu(hidden_states)
+        return hidden_states
+
+
 # feedforward
 class GEGLU(nn.Module):
     r"""
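
For reference, the new `GELU` projection module added above can be exercised on its own. This is a standalone usage sketch: the class body is copied from the diff, the imports match what attention.py already uses, and the tensor shapes are arbitrary. With this commit, passing `activation_fn="gelu"` to `FeedForward` selects this module in place of `GEGLU`.

```python
# Standalone check of the GELU projection module added in this commit.
# Class body copied from the diff; example shapes are arbitrary.
import torch
import torch.nn as nn
import torch.nn.functional as F


class GELU(nn.Module):
    r"""
    GELU activation function
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        return hidden_states


x = torch.randn(2, 16, 320)   # (batch, tokens, dim)
out = GELU(320, 1280)(x)      # project dim -> inner_dim, as FeedForward does
print(out.shape)              # torch.Size([2, 16, 1280])
```

The float32 round-trip in `gelu()` works around the missing float16 GELU kernel on the `mps` backend, as the inline comment notes; on every other device it is a plain `F.gelu`.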