Skip to content

Commit 9bd7580

Browse files
Add multistep DPM-Solver discrete scheduler (huggingface#1132)
* add dpmsolver discrete pytorch scheduler * fix some typos in dpm-solver pytorch * add dpm-solver pytorch in stable-diffusion pipeline * add jax/flax version dpm-solver * change code style * change code style * add docs * add `add_noise` method for dpmsolver * add pytorch unit test for dpmsolver * add dummy object for pytorch dpmsolver * Update src/diffusers/schedulers/scheduling_dpmsolver_discrete.py Co-authored-by: Suraj Patil <[email protected]> * Update tests/test_config.py Co-authored-by: Suraj Patil <[email protected]> * Update tests/test_config.py Co-authored-by: Suraj Patil <[email protected]> * resolve the code comments * rename the file * change class name * fix code style * add auto docs for dpmsolver multistep * add more explanations for the stabilizing trick (for steps < 15) * delete the dummy file * change the API name of predict_epsilon, algorithm_type and solver_type * add compatible lists Co-authored-by: Suraj Patil <[email protected]>
1 parent 3f3b89b commit 9bd7580

14 files changed

+1154
-4
lines changed

Diff for: __init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from .schedulers import (
4343
DDIMScheduler,
4444
DDPMScheduler,
45+
DPMSolverMultistepScheduler,
4546
EulerAncestralDiscreteScheduler,
4647
EulerDiscreteScheduler,
4748
IPNDMScheduler,
@@ -92,6 +93,7 @@
9293
from .schedulers import (
9394
FlaxDDIMScheduler,
9495
FlaxDDPMScheduler,
96+
FlaxDPMSolverMultistepScheduler,
9597
FlaxKarrasVeScheduler,
9698
FlaxLMSDiscreteScheduler,
9799
FlaxPNDMScheduler,

Diff for: pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414

1515
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
1616
from ...pipeline_flax_utils import FlaxDiffusionPipeline
17-
from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler
17+
from ...schedulers import (
18+
FlaxDDIMScheduler,
19+
FlaxDPMSolverMultistepScheduler,
20+
FlaxLMSDiscreteScheduler,
21+
FlaxPNDMScheduler,
22+
)
1823
from ...utils import logging
1924
from . import FlaxStableDiffusionPipelineOutput
2025
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
@@ -43,7 +48,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
4348
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
4449
scheduler ([`SchedulerMixin`]):
4550
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
46-
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
51+
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
52+
[`FlaxDPMSolverMultistepScheduler`].
4753
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
4854
Classification module that estimates whether generated images could be considered offensive or harmful.
4955
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
@@ -57,7 +63,9 @@ def __init__(
5763
text_encoder: FlaxCLIPTextModel,
5864
tokenizer: CLIPTokenizer,
5965
unet: FlaxUNet2DConditionModel,
60-
scheduler: Union[FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler],
66+
scheduler: Union[
67+
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
68+
],
6169
safety_checker: FlaxStableDiffusionSafetyChecker,
6270
feature_extractor: CLIPFeatureExtractor,
6371
dtype: jnp.dtype = jnp.float32,

Diff for: pipelines/stable_diffusion/pipeline_stable_diffusion.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ...pipeline_utils import DiffusionPipeline
1212
from ...schedulers import (
1313
DDIMScheduler,
14+
DPMSolverMultistepScheduler,
1415
EulerAncestralDiscreteScheduler,
1516
EulerDiscreteScheduler,
1617
LMSDiscreteScheduler,
@@ -59,7 +60,12 @@ def __init__(
5960
tokenizer: CLIPTokenizer,
6061
unet: UNet2DConditionModel,
6162
scheduler: Union[
62-
DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
63+
DDIMScheduler,
64+
PNDMScheduler,
65+
LMSDiscreteScheduler,
66+
EulerDiscreteScheduler,
67+
EulerAncestralDiscreteScheduler,
68+
DPMSolverMultistepScheduler,
6369
],
6470
safety_checker: StableDiffusionSafetyChecker,
6571
feature_extractor: CLIPFeatureExtractor,

Diff for: schedulers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
if is_torch_available():
2020
from .scheduling_ddim import DDIMScheduler
2121
from .scheduling_ddpm import DDPMScheduler
22+
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
2223
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
2324
from .scheduling_euler_discrete import EulerDiscreteScheduler
2425
from .scheduling_ipndm import IPNDMScheduler
@@ -35,6 +36,7 @@
3536
if is_flax_available():
3637
from .scheduling_ddim_flax import FlaxDDIMScheduler
3738
from .scheduling_ddpm_flax import FlaxDDPMScheduler
39+
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
3840
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
3941
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
4042
from .scheduling_pndm_flax import FlaxPNDMScheduler

Diff for: schedulers/scheduling_ddim.py

+1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
115115
"LMSDiscreteScheduler",
116116
"EulerDiscreteScheduler",
117117
"EulerAncestralDiscreteScheduler",
118+
"DPMSolverMultistepScheduler",
118119
]
119120

120121
@register_to_config

Diff for: schedulers/scheduling_ddpm.py

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
108108
"LMSDiscreteScheduler",
109109
"EulerDiscreteScheduler",
110110
"EulerAncestralDiscreteScheduler",
111+
"DPMSolverMultistepScheduler",
111112
]
112113

113114
@register_to_config

0 commit comments

Comments
 (0)