Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dpm-solver support (much faster than plms) #440

Merged
merged 2 commits into from
Nov 16, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ldm/models/diffusion/dpm_solver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .sampler import DPMSolverSampler
1,184 changes: 1,184 additions & 0 deletions ldm/models/diffusion/dpm_solver/dpm_solver.py

Large diffs are not rendered by default.

82 changes: 82 additions & 0 deletions ldm/models/diffusion/dpm_solver/sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""SAMPLING ONLY."""

import torch

from .solver import NoiseScheduleVP, model_wrapper, DPM_Solver


class DPMSolverSampler(object):
def __init__(self, model, **kwargs):
super().__init__()
self.model = model
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)

@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

# sampling
C, H, W = shape
size = (batch_size, C, H, W)

# print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')

device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T

ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)

model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type="noise",
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)

dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)

return x.to(device), None
10 changes: 9 additions & 1 deletion scripts/txt2img.py
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
@@ -132,6 +133,11 @@ def main():
action='store_true',
help="use plms sampling",
)
parser.add_argument(
"--dpm_solver",
action='store_true',
help="use dpm_solver sampling",
)
parser.add_argument(
"--laion400m",
action='store_true',
@@ -242,7 +248,9 @@ def main():
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

if opt.plms:
if opt.dpm_solver:
sampler = DPMSolverSampler(model)
elif opt.plms:
sampler = PLMSSampler(model)
else:
sampler = DDIMSampler(model)