81
81
"timestep_spacing" : "leading" ,
82
82
}
83
83
84
+
85
# Hub repo hosting the default config files for every Stable Cascade stage.
_STABLE_CASCADE_CONFIG_REPO = "diffusers/stable-cascade-configs"

# Maps an inferred stage name to the kwargs needed to fetch its config
# (repo id + the subfolder holding that stage's files).
STABLE_CASCADE_DEFAULT_CONFIGS = {
    stage: {"pretrained_model_name_or_path": _STABLE_CASCADE_CONFIG_REPO, "subfolder": subfolder}
    for stage, subfolder in [
        ("stage_c", "prior"),
        ("stage_c_lite", "prior_lite"),
        ("stage_b", "decoder"),
        ("stage_b_lite", "decoder_lite"),
    ]
}
91
+
92
+
93
def convert_stable_cascade_unet_single_file_to_diffusers(original_state_dict):
    """Remap an original Stable Cascade UNet state dict to diffusers key names.

    Fused attention projections (``attn.in_proj_weight`` / ``attn.in_proj_bias``)
    are split along dim 0 into separate ``to_q`` / ``to_k`` / ``to_v`` tensors,
    and ``attn.out_proj.*`` becomes ``to_out.0.*``. For stage-B checkpoints
    (identified by the *absence* of ``clip_txt_mapper.weight``), ``clip_mapper``
    keys are additionally renamed to ``clip_txt_pooled_mapper``; stage-C
    checkpoints keep ``clip_mapper`` keys untouched, matching the original
    per-stage behavior.

    Args:
        original_state_dict: mapping of original checkpoint keys to tensors.

    Returns:
        dict: a new state dict with diffusers-style keys; unmatched keys are
        copied through unchanged.
    """
    # Stage C is the prior; it is the only stage that carries this key.
    is_stage_c = "clip_txt_mapper.weight" in original_state_dict

    # The original implementation duplicated this loop for each stage; the
    # only per-stage difference is the clip_mapper rename, gated below.
    state_dict = {}
    for key, value in original_state_dict.items():
        if key.endswith("in_proj_weight"):
            # Fused QKV weight -> three equal chunks along the output dim.
            to_q, to_k, to_v = value.chunk(3, 0)
            state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = to_q
            state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = to_k
            state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = to_v
        elif key.endswith("in_proj_bias"):
            # Fused QKV bias -> three equal chunks.
            to_q, to_k, to_v = value.chunk(3, 0)
            state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = to_q
            state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = to_k
            state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = to_v
        elif key.endswith("out_proj.weight"):
            state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = value
        elif key.endswith("out_proj.bias"):
            state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = value
        elif not is_stage_c and key.endswith("clip_mapper.weight"):
            # Stage B only: rename clip_mapper to clip_txt_pooled_mapper.
            state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = value
        elif not is_stage_c and key.endswith("clip_mapper.bias"):
            state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = value
        else:
            state_dict[key] = value

    return state_dict
147
+
148
+
149
def infer_stable_cascade_single_file_config(checkpoint):
    """Infer which default Stable Cascade config a single-file checkpoint needs.

    Stage C (prior) checkpoints are identified by ``clip_txt_mapper.weight``;
    the lite variant has output dim 1536, the full variant 2048. Stage B
    (decoder) checkpoints are identified by
    ``down_blocks.1.0.channelwise.0.weight``; the lite variant has last dim
    576, the full variant 640.

    Args:
        checkpoint: state-dict mapping of the loaded single-file checkpoint.

    Returns:
        dict: the matching entry from ``STABLE_CASCADE_DEFAULT_CONFIGS``.

    Raises:
        ValueError: if the checkpoint matches no known Stable Cascade stage.
    """
    is_stage_c = "clip_txt_mapper.weight" in checkpoint
    is_stage_b = "down_blocks.1.0.channelwise.0.weight" in checkpoint

    if is_stage_c and checkpoint["clip_txt_mapper.weight"].shape[0] == 1536:
        config_type = "stage_c_lite"
    elif is_stage_c and checkpoint["clip_txt_mapper.weight"].shape[0] == 2048:
        config_type = "stage_c"
    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 576:
        config_type = "stage_b_lite"
    elif is_stage_b and checkpoint["down_blocks.1.0.channelwise.0.weight"].shape[-1] == 640:
        config_type = "stage_b"
    else:
        # Previously fell through with config_type unbound, surfacing as an
        # opaque UnboundLocalError; fail with an actionable message instead.
        raise ValueError(
            "Unable to infer a Stable Cascade config from the provided checkpoint."
        )

    return STABLE_CASCADE_DEFAULT_CONFIGS[config_type]
163
+
164
+
84
165
DIFFUSERS_TO_LDM_MAPPING = {
85
166
"unet" : {
86
167
"layers" : {
@@ -229,10 +310,34 @@ def fetch_ldm_config_and_checkpoint(
229
310
cache_dir = None ,
230
311
local_files_only = None ,
231
312
revision = None ,
313
+ ):
314
+ checkpoint = load_single_file_model_checkpoint (
315
+ pretrained_model_link_or_path ,
316
+ resume_download = resume_download ,
317
+ force_download = force_download ,
318
+ proxies = proxies ,
319
+ token = token ,
320
+ cache_dir = cache_dir ,
321
+ local_files_only = local_files_only ,
322
+ revision = revision ,
323
+ )
324
+ original_config = fetch_original_config (class_name , checkpoint , original_config_file )
325
+
326
+ return original_config , checkpoint
327
+
328
+
329
+ def load_single_file_model_checkpoint (
330
+ pretrained_model_link_or_path ,
331
+ resume_download = False ,
332
+ force_download = False ,
333
+ proxies = None ,
334
+ token = None ,
335
+ cache_dir = None ,
336
+ local_files_only = None ,
337
+ revision = None ,
232
338
):
233
339
if os .path .isfile (pretrained_model_link_or_path ):
234
340
checkpoint = load_state_dict (pretrained_model_link_or_path )
235
-
236
341
else :
237
342
repo_id , weights_name = _extract_repo_id_and_weights_name (pretrained_model_link_or_path )
238
343
checkpoint_path = _get_model_file (
@@ -252,9 +357,7 @@ def fetch_ldm_config_and_checkpoint(
252
357
while "state_dict" in checkpoint :
253
358
checkpoint = checkpoint ["state_dict" ]
254
359
255
- original_config = fetch_original_config (class_name , checkpoint , original_config_file )
256
-
257
- return original_config , checkpoint
360
+ return checkpoint
258
361
259
362
260
363
def infer_original_config_file (class_name , checkpoint ):
0 commit comments