Commit b39d04d

model_cache: add ability to load a diffusers model pipeline
and update associated things in Generate & Generator to not instantly fail when that happens
Parent: 9f5e496

File tree (5 files changed: +124 -16 lines)

    ldm/generate.py                             (+51 -2)
    ldm/invoke/generator/base.py                (+3 -2)
    ldm/invoke/generator/diffusers_pipeline.py  (+28 -0)
    ldm/invoke/generator/txt2img.py             (+2 -11)
    ldm/invoke/model_cache.py                   (+40 -1)

ldm/generate.py (+51 -2)

@@ -18,6 +18,8 @@
 import hashlib
 import cv2
 import skimage
+from diffusers import DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, \
+    EulerAncestralDiscreteScheduler
 
 from omegaconf import OmegaConf
 from ldm.invoke.generator.base import downsampling
@@ -386,7 +388,10 @@ def process_image(image,seed):
         width = width or self.width
         height = height or self.height
 
-        configure_model_padding(model, seamless, seamless_axes)
+        if isinstance(model, DiffusionPipeline):
+            configure_model_padding(model.unet, seamless, seamless_axes)
+        else:
+            configure_model_padding(model, seamless, seamless_axes)
 
         assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
         assert threshold >= 0.0, '--threshold must be >=0.0'
@@ -930,9 +935,15 @@ def sample_to_image(self, samples):
     def sample_to_lowres_estimated_image(self, samples):
         return self._make_base().sample_to_lowres_estimated_image(samples)
 
+    def _set_sampler(self):
+        if isinstance(self.model, DiffusionPipeline):
+            return self._set_scheduler()
+        else:
+            return self._set_sampler_legacy()
+
     # very repetitive code - can this be simplified? The KSampler names are
     # consistent, at least
-    def _set_sampler(self):
+    def _set_sampler_legacy(self):
         msg = f'>> Setting Sampler to {self.sampler_name}'
         if self.sampler_name == 'plms':
             self.sampler = PLMSSampler(self.model, device=self.device)
@@ -956,6 +967,44 @@ def _set_sampler(self):
 
         print(msg)
 
+    def _set_scheduler(self):
+        msg = f'>> Setting Sampler to {self.sampler_name}'
+        default = self.model.scheduler
+        # TODO: Test me! Not all schedulers take the same args.
+        scheduler_args = dict(
+            num_train_timesteps=default.num_train_timesteps,
+            beta_start=default.beta_start,
+            beta_end=default.beta_end,
+            beta_schedule=default.beta_schedule,
+        )
+        trained_betas = getattr(self.model.scheduler, 'trained_betas')
+        if trained_betas is not None:
+            scheduler_args.update(trained_betas=trained_betas)
+        if self.sampler_name == 'plms':
+            raise NotImplementedError("What's the diffusers implementation of PLMS?")
+        elif self.sampler_name == 'ddim':
+            self.sampler = DDIMScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_dpm_2_a':
+            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
+        elif self.sampler_name == 'k_dpm_2':
+            raise NotImplementedError("no diffusers implementation of dpm_2 samplers")
+        elif self.sampler_name == 'k_euler_a':
+            self.sampler = EulerAncestralDiscreteScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_euler':
+            self.sampler = EulerDiscreteScheduler(**scheduler_args)
+        elif self.sampler_name == 'k_heun':
+            raise NotImplementedError("no diffusers implementation of Heun's sampler")
+        elif self.sampler_name == 'k_lms':
+            self.sampler = LMSDiscreteScheduler(**scheduler_args)
+        else:
+            msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to {default}'
+
+        print(msg)
+
+        if not hasattr(self.sampler, 'uses_inpainting_model'):
+            # FIXME: terrible kludge!
+            self.sampler.uses_inpainting_model = lambda: False
+
     def _load_img(self, img)->Image:
         if isinstance(img, Image.Image):
             image = img

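A note on the "Not all schedulers take the same args" TODO in _set_scheduler: diffusers schedulers can also be rebuilt from the default scheduler's config dict, which sidesteps the per-class constructor differences. A minimal sketch, not part of this commit; the helper name and the mapping are illustrative:

    from diffusers import (DDIMScheduler, EulerAncestralDiscreteScheduler,
                           EulerDiscreteScheduler, LMSDiscreteScheduler)

    # sampler names handled above, mapped to their diffusers scheduler classes
    SCHEDULER_CLASSES = {
        'ddim': DDIMScheduler,
        'k_euler_a': EulerAncestralDiscreteScheduler,
        'k_euler': EulerDiscreteScheduler,
        'k_lms': LMSDiscreteScheduler,
    }

    def scheduler_for(pipeline, sampler_name):
        cls = SCHEDULER_CLASSES.get(sampler_name)
        if cls is None:
            return pipeline.scheduler  # fall back to the model's default
        # from_config copies the relevant settings (betas, timesteps, ...) from the
        # default scheduler's config and tolerates keys the target class does not use
        return cls.from_config(pipeline.scheduler.config)
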
ldm/invoke/generator/base.py (+3 -2)

@@ -9,6 +9,7 @@
 import numpy as np
 import torch
 from PIL import Image, ImageFilter
+from diffusers import DiffusionPipeline
 from einops import rearrange
 from pytorch_lightning import seed_everything
 from tqdm import trange
@@ -24,9 +25,9 @@ class Generator:
     downsampling_factor: int
     latent_channels: int
     precision: str
-    model: DiffusionWrapper
+    model: DiffusionWrapper | DiffusionPipeline
 
-    def __init__(self, model: DiffusionWrapper, precision: str):
+    def __init__(self, model: DiffusionWrapper | DiffusionPipeline, precision: str):
         self.model = model
         self.precision = precision
         self.seed = None

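One caveat on the annotation change above (not part of the commit): the PEP 604 union `DiffusionWrapper | DiffusionPipeline` is evaluated when the class is defined, so it needs Python 3.10, or postponed annotation evaluation on older interpreters. A hedged sketch of the alternatives, assuming base.py's existing imports:

    from __future__ import annotations  # defers evaluation, so `X | Y` annotations are tolerated before 3.10

    # ...or spell it with typing.Union, which evaluates fine on any supported version:
    from typing import Union

    class Generator:
        model: Union[DiffusionWrapper, DiffusionPipeline]

        def __init__(self, model: Union[DiffusionWrapper, DiffusionPipeline], precision: str):
            self.model = model
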
ldm/invoke/generator/diffusers_pipeline.py (+28 -0)

@@ -1,4 +1,5 @@
 import secrets
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
 
@@ -309,6 +310,28 @@ def get_text_embeddings(self,
         text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         return text_embeddings
 
+    def get_learned_conditioning(self, c: List[List[str]], return_tokens=True,
+                                 fragment_weights=None, **kwargs):
+        """
+        Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
+        """
+        assert return_tokens == True
+        if fragment_weights:
+            weights = fragment_weights[0]
+            if any(weight != 1.0 for weight in weights):
+                warnings.warn(f"fragment weights not implemented yet {fragment_weights}", stacklevel=2)
+
+        if kwargs:
+            warnings.warn(f"unsupported args {kwargs}", stacklevel=2)
+
+        text_fragments = c[0]
+        text_input = self._tokenize(text_fragments)
+
+        with torch.inference_mode():
+            token_ids = text_input.input_ids.to(self.text_encoder.device)
+            text_embeddings = self.text_encoder(token_ids)[0]
+        return text_embeddings, text_input.input_ids
+
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
         return self.tokenizer(
@@ -319,6 +342,11 @@ def _tokenize(self, prompt: Union[str, List[str]]):
             return_tensors="pt",
         )
 
+    @property
+    def channels(self) -> int:
+        """Compatible with DiffusionWrapper"""
+        return self.unet.in_channels
+
     def prepare_latents(self, latents, batch_size, height, width, generator, dtype):
         # get the initial random noise unless the user supplied it
         # Unlike in other pipelines, latents need to be generated in the target device

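For orientation, a sketch of how the two new compatibility shims are meant to be called (assuming `pipe` is a loaded StableDiffusionGeneratorPipeline; the prompt is illustrative):

    # get_learned_conditioning mirrors LatentDiffusion's interface: a list of
    # prompt-fragment lists in, text-encoder embeddings plus the token ids out
    embeddings, token_ids = pipe.get_learned_conditioning([["a photograph of an astronaut"]])
    # embeddings has shape (batch, sequence_length, hidden_size) from the text encoder

    # channels mirrors the attribute callers previously read off DiffusionWrapper,
    # here answered by the UNet's input channel count
    latent_channels = pipe.channels
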
ldm/invoke/generator/txt2img.py (+2 -11)

@@ -24,17 +24,8 @@ def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
         self.perlin = perlin
         uc, c, extra_conditioning_info = conditioning
 
-        # FIXME: this should probably be either passed in to __init__ instead of model & precision,
-        # or be constructed in __init__ from those inputs.
-        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
-            "runwayml/stable-diffusion-v1-5",
-            revision="fp16", torch_dtype=torch.float16,
-            safety_checker=None,  # TODO
-            # scheduler=sampler + ddim_eta,  # TODO
-            # TODO: local_files_only=True
-        )
-        pipeline.unet.to("cuda")
-        pipeline.vae.to("cuda")
+        pipeline = self.model
+        # TODO: customize a new pipeline for the given sampler (Scheduler)
 
         def make_image(x_T) -> PIL.Image.Image:
             # FIXME: restore free_gpu_mem functionality

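One possible shape for the TODO above, sketched only and not part of this commit: Generate._set_scheduler already builds a diffusers scheduler for the requested sampler name, so the generator could install it on the pipeline before sampling.

    # `sampler` is the argument get_make_image() already receives; with a diffusers
    # model it would be the Scheduler instance chosen in Generate._set_scheduler()
    from diffusers import SchedulerMixin

    if isinstance(sampler, SchedulerMixin):
        pipeline.scheduler = sampler  # note: this mutates the cached pipeline
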
ldm/invoke/model_cache.py (+40 -1)

@@ -4,6 +4,7 @@
 below a preset minimum, the least recently used model will be
 cleared and loaded from disk when next needed.
 '''
+from pathlib import Path
 
 import torch
 import os
@@ -18,6 +19,8 @@
 from sys import getrefcount
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
+
+from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ldm.util import instantiate_from_config
 
 DEFAULT_MAX_MODELS=2
@@ -268,7 +271,43 @@ def _load_ckpt_model(self, mconfig):
         return model, width, height, model_hash
 
     def _load_diffusers_model(self, mconfig):
-        raise NotImplementedError() # return pipeline, width, height, model_hash
+        pipeline_args = {}
+
+        if 'repo_name' in mconfig:
+            name_or_path = mconfig['repo_name']
+            model_hash = "FIXME"
+            # model_hash = huggingface_hub.get_hf_file_metadata(url).commit_hash
+        elif 'path' in mconfig:
+            name_or_path = Path(mconfig['path'])
+            # FIXME: What should the model_hash be? A hash of the unet weights? Of all files of all
+            # the submodels hashed together? The commit ID from the repo?
+            model_hash = "FIXME TOO"
+        else:
+            raise ValueError("Model config must specify either repo_name or path.")
+
+        print(f'>> Loading diffusers model from {name_or_path}')
+
+        if self.precision == 'float16':
+            print('  | Using faster float16 precision')
+            pipeline_args.update(revision="fp16", torch_dtype=torch.float16)
+        else:
+            # TODO: more accurately, "using the model's default precision."
+            # How do we find out what that is?
+            print('  | Using more accurate float32 precision')
+
+        pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
+            name_or_path,
+            safety_checker=None,  # TODO
+            # TODO: alternate VAE
+            # TODO: local_files_only=True
+            **pipeline_args
+        )
+        pipeline.to(self.device)
+
+        width = pipeline.vae.sample_size
+        height = pipeline.vae.sample_size
+
+        return pipeline, width, height, model_hash
 
     def offload_model(self, model_name:str):
         '''

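For reference, the two model-config shapes _load_diffusers_model now accepts (key names taken from the code above; the repo name is the one previously hard-coded in txt2img.py, and the local path is purely illustrative):

    from omegaconf import OmegaConf

    # load from the Hugging Face hub by repository name...
    mconfig = OmegaConf.create({'repo_name': 'runwayml/stable-diffusion-v1-5'})

    # ...or from a local directory already laid out in diffusers format
    mconfig = OmegaConf.create({'path': '/path/to/stable-diffusion-v1-5'})
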