diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index e2800342e578..283ddb6f19b3 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -13,12 +13,13 @@ # limitations under the License. import importlib +import inspect import warnings from typing import Callable, List, Optional, Union import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser -from k_diffusion.sampling import get_sigmas_karras +from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin @@ -464,6 +465,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, use_karras_sigmas: Optional[bool] = False, + noise_sampler_seed: Optional[int] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -524,6 +526,8 @@ def __call__( Use karras sigmas. For example, specifying `sample_dpmpp_2m` to `set_scheduler` will be equivalent to `DPM++2M` in stable-diffusion-webui. On top of that, setting this option to True will make it `DPM++2M Karras`. + noise_sampler_seed (`int`, *optional*, defaults to `None`): + The random seed to use for the noise sampler. If `None`, a random seed will be generated. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. @@ -608,7 +612,14 @@ def model_fn(x, t): return noise_pred # 8. Run k-diffusion solver - latents = self.sampler(model_fn, latents, sigmas) + sampler_kwargs = {} + + if "noise_sampler" in inspect.signature(self.sampler).parameters: + min_sigma, max_sigma = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) + sampler_kwargs["noise_sampler"] = noise_sampler + + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py index 4eccb871a0cb..25da13d9f922 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py @@ -104,3 +104,33 @@ def test_stable_diffusion_karras_sigmas(self): ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + def test_stable_diffusion_noise_sampler_seed(self): + sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + sd_pipe.set_scheduler("sample_dpmpp_sde") + + prompt = "A painting of a squirrel eating a burger" + seed = 0 + images1 = sd_pipe( + [prompt], + generator=torch.manual_seed(seed), + noise_sampler_seed=seed, + guidance_scale=9.0, + num_inference_steps=20, + output_type="np", + ).images + images2 = sd_pipe( + [prompt], + generator=torch.manual_seed(seed), + noise_sampler_seed=seed, + guidance_scale=9.0, + num_inference_steps=20, + output_type="np", + ).images + + assert images1.shape == (1, 512, 512, 3) + assert images2.shape == (1, 512, 512, 3) + assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2