Commit 9f5e496

refactor(model_cache): factor out load_ckpt
1 parent a267b45

File tree

1 file changed (+39 -24 lines)

ldm/invoke/model_cache.py

@@ -94,7 +94,7 @@ def get_model(self, model_name:str):
             'hash': hash
         }

-    def default_model(self) -> str:
+    def default_model(self) -> str | None:
         '''
         Returns the name of the default model, or None
         if none is defined.
@@ -191,13 +191,6 @@ def _load_model(self, model_name:str):
             return None

         mconfig = self.config[model_name]
-        config = mconfig.config
-        weights = mconfig.weights
-        vae = mconfig.get('vae',None)
-        width = mconfig.width
-        height = mconfig.height
-
-        print(f'>> Loading {model_name} from {weights}')

         # for usage statistics
         if self._has_cuda():
@@ -207,15 +200,44 @@ def _load_model(self, model_name:str):
         tic = time.time()

         # this does the work
-        c = OmegaConf.load(config)
-        with open(weights,'rb') as f:
+        model_format = mconfig.get('format', 'ckpt')
+        if model_format == 'ckpt':
+            weights = mconfig.weights
+            print(f'>> Loading {model_name} from {weights}')
+            model, width, height, model_hash = self._load_ckpt_model(mconfig)
+        elif model_format == 'diffusers':
+            model, width, height, model_hash = self._load_diffusers_model(mconfig)
+        else:
+            raise NotImplementedError(f"Unknown model format {model_name}: {model_format}")
+
+        # usage statistics
+        toc = time.time()
+        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
+        if self._has_cuda():
+            print(
+                '>> Max VRAM used to load the model:',
+                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
+                '\n>> Current VRAM usage:'
+                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
+            )
+        return model, width, height, model_hash
+
+    def _load_ckpt_model(self, mconfig):
+        config = mconfig.config
+        weights = mconfig.weights
+        vae = mconfig.get('vae', None)
+        width = mconfig.width
+        height = mconfig.height
+
+        c = OmegaConf.load(config)
+        with open(weights, 'rb') as f:
             weight_bytes = f.read()
-        model_hash = self._cached_sha256(weights,weight_bytes)
+        model_hash = self._cached_sha256(weights, weight_bytes)
         pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
         del weight_bytes
-        sd = pl_sd['state_dict']
+        sd = pl_sd['state_dict']
         model = instantiate_from_config(c.model)
-        m, u = model.load_state_dict(sd, strict=False)
+        m, u = model.load_state_dict(sd, strict=False)

         if self.precision == 'float16':
             print(' | Using faster float16 precision')
@@ -243,18 +265,11 @@ def _load_model(self, model_name:str):
             if isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                 m._orig_padding_mode = m.padding_mode

-        # usage statistics
-        toc = time.time()
-        print(f'>> Model loaded in', '%4.2fs' % (toc - tic))
-        if self._has_cuda():
-            print(
-                '>> Max VRAM used to load the model:',
-                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
-                '\n>> Current VRAM usage:'
-                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
-            )
         return model, width, height, model_hash
-
+
+    def _load_diffusers_model(self, mconfig):
+        raise NotImplementedError() # return pipeline, width, height, model_hash
+
     def offload_model(self, model_name:str):
         '''
         Offload the indicated model to CPU. Will call
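
For orientation: the new dispatch reads an optional format key from the model's entry in the OmegaConf-backed models config, defaulting to 'ckpt'. A minimal sketch of what the two kinds of entries might look like, using only the field names the diff actually reads (format, config, weights, vae, width, height); the paths and the repo_id field on the diffusers entry are illustrative assumptions, since _load_diffusers_model is still a stub and defines no schema:

from omegaconf import OmegaConf

# Sketch of two model entries; paths and the repo_id field are illustrative.
models = OmegaConf.create("""
stable-diffusion-1.5:
  format: ckpt                  # 'ckpt' is also the default when omitted
  config: configs/stable-diffusion/v1-inference.yaml
  weights: models/ldm/stable-diffusion-v1/model.ckpt
  vae: models/ldm/stable-diffusion-v1/vae.ckpt
  width: 512
  height: 512
my-diffusers-model:
  format: diffusers             # routed to _load_diffusers_model()
  repo_id: someuser/some-model  # hypothetical field; the stub defines no schema
""")

# Same lookup and dispatch key that _load_model() uses:
mconfig = models['stable-diffusion-1.5']
print(mconfig.get('format', 'ckpt'))   # -> ckpt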
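The _load_diffusers_model stub pins down only the return contract (pipeline, width, height, model_hash). As a rough sketch of how it might later be filled in with Hugging Face's diffusers library; the repo_id field, the size defaults, and the placeholder hash are assumptions rather than anything this commit specifies:

from diffusers import StableDiffusionPipeline

def _load_diffusers_model(self, mconfig):
    # Hypothetical sketch only; the commit leaves this method as a stub.
    # Assumes a 'repo_id' field in the model config entry.
    repo_id = mconfig['repo_id']
    pipeline = StableDiffusionPipeline.from_pretrained(repo_id)
    width = mconfig.get('width', 512)
    height = mconfig.get('height', 512)
    # No checkpoint file to sha256 here; use the repo id as a stand-in hash.
    model_hash = repo_id
    return pipeline, width, height, model_hash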
