@@ -94,7 +94,7 @@ def get_model(self, model_name:str):
             'hash': hash
         }
 
-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
         '''
         Returns the name of the default model, or None
         if none is defined.
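(Note: the `str | None` union syntax in annotations requires Python 3.10+; on older interpreters the equivalent is `Optional[str]`, or `from __future__ import annotations` to defer evaluation.)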
@@ -191,13 +191,6 @@ def _load_model(self, model_name:str):
             return None
 
         mconfig = self.config[model_name]
-        config = mconfig.config
-        weights = mconfig.weights
-        vae = mconfig.get('vae',None)
-        width = mconfig.width
-        height = mconfig.height
-
-        print(f'>> Loading {model_name} from {weights}')
 
         # for usage statistics
         if self._has_cuda():
@@ -207,15 +200,44 @@ def _load_model(self, model_name:str):
         tic = time.time()
 
         # this does the work
-        c = OmegaConf.load(config)
-        with open(weights,'rb') as f:
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, mconfig):
+        config = mconfig.config
+        weights = mconfig.weights
+        vae = mconfig.get('vae', None)
+        width = mconfig.width
+        height = mconfig.height
+
+        c = OmegaConf.load(config)
+        with open(weights, 'rb') as f:
             weight_bytes = f.read()
-        model_hash = self._cached_sha256(weights,weight_bytes)
+        model_hash = self._cached_sha256(weights, weight_bytes)
         pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
         del weight_bytes
-        sd    = pl_sd['state_dict']
+        sd = pl_sd['state_dict']
         model = instantiate_from_config(c.model)
-        m, u  = model.load_state_dict(sd, strict=False)
+        m, u = model.load_state_dict(sd, strict=False)
 
         if self.precision == 'float16':
             print('   | Using faster float16 precision')
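The new dispatch keys on an optional `format` field in each model's config entry, defaulting to `'ckpt'` so existing configs keep working. A minimal sketch of what such entries could look like, assuming the OmegaConf-style config used above (the model names, paths, and the diffusers entry's fields are illustrative, not taken from this commit):

```python
from omegaconf import OmegaConf

# Hypothetical entries; only the `format` key and the ckpt fields
# (config, weights, vae, width, height) are implied by the diff above.
conf = OmegaConf.create("""
stable-diffusion-1.5:
  format: ckpt          # omitting `format` also falls back to 'ckpt'
  config: configs/stable-diffusion/v1-inference.yaml
  weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
  width: 512
  height: 512
some-diffusers-model:
  format: diffusers     # routed to the _load_diffusers_model() stub
  width: 512
  height: 512
""")

print(conf['some-diffusers-model'].get('format', 'ckpt'))  # -> diffusers
```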
@@ -243,18 +265,11 @@ def _load_model(self, model_name:str):
             if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 m._orig_padding_mode = m.padding_mode
 
-        # usage statistics
-        toc = time.time()
-        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
         return model, width, height, model_hash
-
+
+    def _load_diffusers_model(self, mconfig):
+        raise NotImplementedError() # return pipeline, width, height, model_hash
+
     def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
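`_load_diffusers_model` is left as a stub whose trailing comment pins its contract to the same `(pipeline, width, height, model_hash)` shape that `_load_ckpt_model` returns. A hedged sketch of one possible direction, assuming a hypothetical `repo_id` config key (a Hugging Face hub id or local directory) and the diffusers library's `StableDiffusionPipeline.from_pretrained`; none of this is part of the commit:

```python
from diffusers import StableDiffusionPipeline

def _load_diffusers_model(self, mconfig):
    # `repo_id` is an assumed config key, not one defined by this commit.
    repo_id = mconfig.repo_id
    pipeline = StableDiffusionPipeline.from_pretrained(repo_id)
    width = mconfig.width
    height = mconfig.height
    # Placeholder: a real implementation would hash the downloaded snapshot.
    model_hash = repo_id
    return pipeline, width, height, model_hash
```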