@@ -973,3 +973,178 @@ def swap_scale_shift(weight):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict
+
+
+def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
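+    # The next two helpers split the fused QKV LoRA layers. For a fused projection, the LoRA update is B @ A with B's
+    # rows laid out as [q; k; v], so the down projection (lora_A) is shared (copied) across to_q/to_k/to_v while the
+    # up projection (lora_B) is chunked row-wise into the per-projection matrices.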
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
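+    # single_blocks fuse attention QKV and the MLP input projection into one "linear1" layer; the helper below splits
+    # its LoRA weights into to_q/to_k/to_v/proj_mlp, again sharing lora_A and splitting lora_B (and biases) row-wise.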
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+                q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
+                state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
+                state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
+                state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
+
+        elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+                q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
+                state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
+                state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
+                state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    # Some LoRAs ship with a "transformer." (or "diffusion_model.") prefix already prepended to every key, usually by
+    # authors trying to make their state dict look diffusers-compatible for their own custom loading code. To make both
+    # the "original" and these "attempted diffusers" LoRAs work as expected, normalize them to the same initial format
+    # by stripping those prefixes before remapping.
+    for key in list(converted_state_dict.keys()):
+        if key.startswith("transformer."):
+            converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+        if key.startswith("diffusion_model."):
+            converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+    # Rename and remap the state dict keys
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    # Add back the "transformer." prefix
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
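
A minimal sketch (not part of the diff above) of how the converter can be sanity-checked on a tiny hand-built HunyuanVideo LoRA state dict. The rank of 4, the zero-filled tensors, and the two example modules are arbitrary illustration choices, and the import path is an assumption based on where this helper lives in the diffusers source tree; in normal use the HunyuanVideo LoRA loader is expected to call this conversion internally rather than users calling it directly.

# Illustrative sanity check; rank, shapes, and module choice are arbitrary.
import torch

from diffusers.loaders.lora_conversion_utils import (  # assumed import path
    _convert_hunyuan_video_lora_to_diffusers,
)

rank, hidden_size = 4, 3072  # 3072 matches the hidden_size hard-coded for single_blocks above

fake_lora = {
    # Double block: fused image-attention QKV.
    "double_blocks.0.img_attn_qkv.lora_A.weight": torch.zeros(rank, hidden_size),
    "double_blocks.0.img_attn_qkv.lora_B.weight": torch.zeros(3 * hidden_size, rank),
    # Single block: linear1 fuses q, k, v plus a 4x-wide MLP input projection.
    "single_blocks.0.linear1.lora_A.weight": torch.zeros(rank, hidden_size),
    "single_blocks.0.linear1.lora_B.weight": torch.zeros(7 * hidden_size, rank),
}

converted = _convert_hunyuan_video_lora_to_diffusers(fake_lora)

# lora_B rows were split per projection; lora_A was duplicated, so the rank is preserved everywhere.
assert converted["transformer.transformer_blocks.0.attn.to_q.lora_B.weight"].shape == (hidden_size, rank)
assert converted["transformer.transformer_blocks.0.attn.to_q.lora_A.weight"].shape == (rank, hidden_size)
assert converted["transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight"].shape == (4 * hidden_size, rank)
assert "transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight" in converted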