@@ -906,6 +906,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.processor(self, hidden_states)
 
 
+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        added_kv_proj_dim: int,
+        processor: "MochiAttnProcessor2_0",
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_proj_bias: bool = True,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        from .normalization import MochiRMSNorm
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "MochiAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+        total_length = sequence_length + encoder_sequence_length
+
+        batch_size, heads, _, dim = query.shape
+        attn_outputs = []
+        for idx in range(batch_size):
+            mask = attention_mask[idx][None, :]
+            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
+
+            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
+            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
+            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
+
+            attn_output = F.scaled_dot_product_attention(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
+            )
+            valid_sequence_length = attn_output.size(2)
+            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
+            attn_outputs.append(attn_output)
+
+        hidden_states = torch.cat(attn_outputs, dim=0)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if hasattr(attn, "to_add_out"):
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
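Note on the hunk above: `MochiAttention` bundles the usual `to_q`/`to_k`/`to_v` projections with added (`add_*_proj`) projections for the text stream and per-head `MochiRMSNorm`, and delegates its forward pass to `MochiAttnProcessor2_0`. A minimal usage sketch follows; the dimension values, sequence lengths, and import path are illustrative assumptions, not part of this diff.

```python
import torch

# Hypothetical usage sketch of the classes added above. Widths, token counts,
# and the import path are assumptions for illustration only.
from diffusers.models.attention_processor import MochiAttention, MochiAttnProcessor2_0

attn = MochiAttention(
    query_dim=512,            # width of the visual tokens (assumed)
    added_kv_proj_dim=256,    # width of the text/encoder tokens (assumed)
    processor=MochiAttnProcessor2_0(),
    heads=8,
    dim_head=64,              # inner_dim = heads * dim_head = 512
    context_pre_only=False,   # so to_add_out is created for the text stream
)

hidden_states = torch.randn(2, 128, 512)         # (batch, video tokens, query_dim)
encoder_hidden_states = torch.randn(2, 16, 256)  # (batch, prompt tokens, added_kv_proj_dim)
attention_mask = torch.zeros(2, 16, dtype=torch.bool)
attention_mask[:, :10] = True                    # only the first 10 prompt tokens are real

# The module simply forwards to the processor, which returns both streams.
out, encoder_out = attn(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
)
print(out.shape, encoder_out.shape)  # torch.Size([2, 128, 512]) torch.Size([2, 16, 512])
```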
@@ -3868,94 +4039,6 @@ def __call__(
         return hidden_states
 
 
-class MochiAttnProcessor2_0:
-    """Attention processor used in Mochi."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        query = query.unflatten(2, (attn.heads, -1))
-        key = key.unflatten(2, (attn.heads, -1))
-        value = value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        encoder_query = attn.add_q_proj(encoder_hidden_states)
-        encoder_key = attn.add_k_proj(encoder_hidden_states)
-        encoder_value = attn.add_v_proj(encoder_hidden_states)
-
-        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
-        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
-        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_added_q is not None:
-            encoder_query = attn.norm_added_q(encoder_query)
-        if attn.norm_added_k is not None:
-            encoder_key = attn.norm_added_k(encoder_key)
-
-        if image_rotary_emb is not None:
-
-            def apply_rotary_emb(x, freqs_cos, freqs_sin):
-                x_even = x[..., 0::2].float()
-                x_odd = x[..., 1::2].float()
-
-                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
-                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
-
-                return torch.stack([cos, sin], dim=-1).flatten(-2)
-
-            query = apply_rotary_emb(query, *image_rotary_emb)
-            key = apply_rotary_emb(key, *image_rotary_emb)
-
-        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
-        encoder_query, encoder_key, encoder_value = (
-            encoder_query.transpose(1, 2),
-            encoder_key.transpose(1, 2),
-            encoder_value.transpose(1, 2),
-        )
-
-        sequence_length = query.size(2)
-        encoder_sequence_length = encoder_query.size(2)
-
-        query = torch.cat([query, encoder_query], dim=2)
-        key = torch.cat([key, encoder_key], dim=2)
-        value = torch.cat([value, encoder_value], dim=2)
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
-
-        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
-            (sequence_length, encoder_sequence_length), dim=1
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if getattr(attn, "to_add_out", None) is not None:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        return hidden_states, encoder_hidden_states
-
-
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
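The hunk above removes the earlier `MochiAttnProcessor2_0`, which concatenated the full (padded) prompt sequence and ran one `scaled_dot_product_attention` call without consulting `attention_mask`; the replacement added earlier instead loops over the batch, keeps only the valid prompt tokens per sample, and pads the output back to a common length. Below is a standalone sketch of that gather/attend/pad pattern with illustrative tensors; it is not the library's code.

```python
import torch
import torch.nn.functional as F

# Illustrative tensors: one sample, 8 heads, 128 video tokens, 16 prompt tokens, head dim 64.
q = torch.randn(1, 8, 128, 64)
k, v = torch.randn_like(q), torch.randn_like(q)
enc_q = torch.randn(1, 8, 16, 64)
enc_k, enc_v = torch.randn_like(enc_q), torch.randn_like(enc_q)

mask = torch.zeros(16, dtype=torch.bool)
mask[:10] = True  # only the first 10 prompt tokens are real; the rest is padding
valid = torch.nonzero(mask, as_tuple=False).flatten()

# Attend over the video tokens plus only the valid prompt tokens.
out = F.scaled_dot_product_attention(
    torch.cat([q, enc_q[:, :, valid]], dim=2),
    torch.cat([k, enc_k[:, :, valid]], dim=2),
    torch.cat([v, enc_v[:, :, valid]], dim=2),
)

# Pad the token dimension back to the full (video + prompt) length so the
# per-sample results can be re-stacked into a batch, as the new processor does.
total_length = q.size(2) + enc_q.size(2)
out = F.pad(out, (0, 0, 0, total_length - out.size(2)))
print(out.shape)  # torch.Size([1, 8, 144, 64])
```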
@@ -5668,13 +5751,13 @@ def __call__(
     AttnProcessorNPU,
     AttnProcessor2_0,
     MochiVaeAttnProcessor2_0,
+    MochiAttnProcessor2_0,
     StableAudioAttnProcessor2_0,
     HunyuanAttnProcessor2_0,
     FusedHunyuanAttnProcessor2_0,
     PAGHunyuanAttnProcessor2_0,
     PAGCFGHunyuanAttnProcessor2_0,
     LuminaAttnProcessor2_0,
-    MochiAttnProcessor2_0,
     FusedAttnProcessor2_0,
     CustomDiffusionXFormersAttnProcessor,
     CustomDiffusionAttnProcessor2_0,