Skip to content

Commit 2e8668f

Browse files
Correct controlnet out of list error (#3928)
* Correct controlnet out of list error * Apply suggestions from code review * correct tests * correct tests * fix * test all * Apply suggestions from code review * test all * test all * Apply suggestions from code review * Apply suggestions from code review * fix more tests * Fix more * Apply suggestions from code review * finish * Apply suggestions from code review * Update src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py * finish
1 parent b298484 commit 2e8668f

27 files changed

+225
-51
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -947,9 +947,9 @@ def __call__(
947947

948948
# 7.1 Create tensor stating which controlnets to keep
949949
controlnet_keep = []
950-
for i in range(num_inference_steps):
950+
for i in range(len(timesteps)):
951951
keeps = [
952-
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
952+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
953953
for s, e in zip(control_guidance_start, control_guidance_end)
954954
]
955955
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,9 +1040,9 @@ def __call__(
10401040

10411041
# 7.1 Create tensor stating which controlnets to keep
10421042
controlnet_keep = []
1043-
for i in range(num_inference_steps):
1043+
for i in range(len(timesteps)):
10441044
keeps = [
1045-
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
1045+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
10461046
for s, e in zip(control_guidance_start, control_guidance_end)
10471047
]
10481048
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,9 +1275,9 @@ def __call__(
12751275

12761276
# 7.1 Create tensor stating which controlnets to keep
12771277
controlnet_keep = []
1278-
for i in range(num_inference_steps):
1278+
for i in range(len(timesteps)):
12791279
keeps = [
1280-
1.0 - float(i / num_inference_steps < s or (i + 1) / num_inference_steps > e)
1280+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
12811281
for s, e in zip(control_guidance_start, control_guidance_end)
12821282
]
12831283
controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def __call__(
374374
# predicted_original_sample instead of the noise_pred. So we need to compute the
375375
# predicted_original_sample here if we are using a karras style scheduler.
376376
if scheduler_is_in_sigma_space:
377-
step_index = (self.scheduler.timesteps == t).nonzero().item()
377+
step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
378378
sigma = self.scheduler.sigmas[step_index]
379379
noise_pred = latent_model_input - sigma * noise_pred
380380

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
103103
lower_order_final (`bool`, default `True`):
104104
whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically
105105
find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10.
106-
106+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
107+
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
108+
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
109+
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
107110
"""
108111

109112
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -125,6 +128,7 @@ def __init__(
125128
algorithm_type: str = "deis",
126129
solver_type: str = "logrho",
127130
lower_order_final: bool = True,
131+
use_karras_sigmas: Optional[bool] = False,
128132
):
129133
if trained_betas is not None:
130134
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -188,6 +192,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
188192
.astype(np.int64)
189193
)
190194

195+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
196+
if self.config.use_karras_sigmas:
197+
log_sigmas = np.log(sigmas)
198+
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
199+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
200+
timesteps = np.flip(timesteps).copy().astype(np.int64)
201+
202+
self.sigmas = torch.from_numpy(sigmas)
203+
191204
# when num_inference_steps == num_train_timesteps, we can end up with
192205
# duplicates in timesteps.
193206
_, unique_indices = np.unique(timesteps, return_index=True)

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ def __init__(
203203
self.timesteps = torch.from_numpy(timesteps)
204204
self.model_outputs = [None] * solver_order
205205
self.lower_order_nums = 0
206-
self.use_karras_sigmas = use_karras_sigmas
207206

208207
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
209208
"""
@@ -225,13 +224,15 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
225224
.astype(np.int64)
226225
)
227226

228-
if self.use_karras_sigmas:
229-
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
227+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
228+
if self.config.use_karras_sigmas:
230229
log_sigmas = np.log(sigmas)
231230
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
232231
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
233232
timesteps = np.flip(timesteps).copy().astype(np.int64)
234233

234+
self.sigmas = torch.from_numpy(sigmas)
235+
235236
# when num_inference_steps == num_train_timesteps, we can end up with
236237
# duplicates in timesteps.
237238
_, unique_indices = np.unique(timesteps, return_index=True)

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ def __init__(
202202
self.model_outputs = [None] * solver_order
203203
self.sample = None
204204
self.order_list = self.get_order_list(num_train_timesteps)
205-
self.use_karras_sigmas = use_karras_sigmas
206205

207206
def get_order_list(self, num_inference_steps: int) -> List[int]:
208207
"""
@@ -259,13 +258,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
259258
.astype(np.int64)
260259
)
261260

262-
if self.use_karras_sigmas:
263-
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
261+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
262+
if self.config.use_karras_sigmas:
264263
log_sigmas = np.log(sigmas)
265264
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
266265
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
267266
timesteps = np.flip(timesteps).copy().astype(np.int64)
268267

268+
self.sigmas = torch.from_numpy(sigmas)
269+
269270
self.timesteps = torch.from_numpy(timesteps).to(device)
270271
self.model_outputs = [None] * self.config.solver_order
271272
self.sample = None

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
117117
by disable the corrector at the first few steps (e.g., disable_corrector=[0])
118118
solver_p (`SchedulerMixin`, default `None`):
119119
can be any other scheduler. If specified, the algorithm will become solver_p + UniC.
120+
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
121+
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
122+
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
123+
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
120124
"""
121125

122126
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -140,6 +144,7 @@ def __init__(
140144
lower_order_final: bool = True,
141145
disable_corrector: List[int] = [],
142146
solver_p: SchedulerMixin = None,
147+
use_karras_sigmas: Optional[bool] = False,
143148
):
144149
if trained_betas is not None:
145150
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -201,6 +206,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
201206
.astype(np.int64)
202207
)
203208

209+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
210+
if self.config.use_karras_sigmas:
211+
log_sigmas = np.log(sigmas)
212+
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
213+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
214+
timesteps = np.flip(timesteps).copy().astype(np.int64)
215+
216+
self.sigmas = torch.from_numpy(sigmas)
217+
204218
# when num_inference_steps == num_train_timesteps, we can end up with
205219
# duplicates in timesteps.
206220
_, unique_indices = np.unique(timesteps, return_index=True)

tests/pipelines/altdiffusion/test_alt_diffusion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@
2929
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
3030

3131
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
32-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
32+
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
3333

3434

3535
enable_full_determinism()
3636

3737

38-
class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
38+
class AltDiffusionPipelineFastTests(
39+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
40+
):
3941
pipeline_class = AltDiffusionPipeline
4042
params = TEXT_TO_IMAGE_PARAMS
4143
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

tests/pipelines/controlnet/test_controlnet.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@
4646
TEXT_TO_IMAGE_IMAGE_PARAMS,
4747
TEXT_TO_IMAGE_PARAMS,
4848
)
49-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
49+
from ..test_pipelines_common import (
50+
PipelineKarrasSchedulerTesterMixin,
51+
PipelineLatentTesterMixin,
52+
PipelineTesterMixin,
53+
)
5054

5155

5256
enable_full_determinism()
@@ -97,7 +101,9 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
97101
out_queue.join()
98102

99103

100-
class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
104+
class ControlNetPipelineFastTests(
105+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
106+
):
101107
pipeline_class = StableDiffusionControlNetPipeline
102108
params = TEXT_TO_IMAGE_PARAMS
103109
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -207,7 +213,9 @@ def test_inference_batch_single_identical(self):
207213
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
208214

209215

210-
class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
216+
class StableDiffusionMultiControlNetPipelineFastTests(
217+
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
218+
):
211219
pipeline_class = StableDiffusionControlNetPipeline
212220
params = TEXT_TO_IMAGE_PARAMS
213221
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

tests/pipelines/controlnet/test_controlnet_img2img.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,19 @@
4242
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
4343
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
4444
)
45-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
45+
from ..test_pipelines_common import (
46+
PipelineKarrasSchedulerTesterMixin,
47+
PipelineLatentTesterMixin,
48+
PipelineTesterMixin,
49+
)
4650

4751

4852
enable_full_determinism()
4953

5054

51-
class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
55+
class ControlNetImg2ImgPipelineFastTests(
56+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
57+
):
5258
pipeline_class = StableDiffusionControlNetImg2ImgPipeline
5359
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
5460
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
@@ -161,7 +167,9 @@ def test_inference_batch_single_identical(self):
161167
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
162168

163169

164-
class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
170+
class StableDiffusionMultiControlNetPipelineFastTests(
171+
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
172+
):
165173
pipeline_class = StableDiffusionControlNetImg2ImgPipeline
166174
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
167175
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS

tests/pipelines/controlnet/test_controlnet_inpaint.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,19 @@
4242
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
4343
TEXT_TO_IMAGE_IMAGE_PARAMS,
4444
)
45-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
45+
from ..test_pipelines_common import (
46+
PipelineKarrasSchedulerTesterMixin,
47+
PipelineLatentTesterMixin,
48+
PipelineTesterMixin,
49+
)
4650

4751

4852
enable_full_determinism()
4953

5054

51-
class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
55+
class ControlNetInpaintPipelineFastTests(
56+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
57+
):
5258
pipeline_class = StableDiffusionControlNetInpaintPipeline
5359
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
5460
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
@@ -237,7 +243,9 @@ def get_dummy_components(self):
237243
return components
238244

239245

240-
class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
246+
class MultiControlNetInpaintPipelineFastTests(
247+
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
248+
):
241249
pipeline_class = StableDiffusionControlNetInpaintPipeline
242250
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
243251
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from ...models.test_lora_layers import create_unet_lora_layers
5151
from ...models.test_models_unet_2d_condition import create_lora_layers
5252
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
53-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
53+
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
5454

5555

5656
enable_full_determinism()
@@ -88,7 +88,9 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
8888
out_queue.join()
8989

9090

91-
class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
91+
class StableDiffusionPipelineFastTests(
92+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
93+
):
9294
pipeline_class = StableDiffusionPipeline
9395
params = TEXT_TO_IMAGE_PARAMS
9496
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,14 @@
3333
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
3434

3535
from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
36-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
36+
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
3737

3838

3939
enable_full_determinism()
4040

4141

4242
class StableDiffusionImageVariationPipelineFastTests(
43-
PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
43+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
4444
):
4545
pipeline_class = StableDiffusionImageVariationPipeline
4646
params = IMAGE_VARIATION_PARAMS

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
4747
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
4848
)
49-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
49+
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
5050

5151

5252
enable_full_determinism()
@@ -84,7 +84,9 @@ def _test_img2img_compile(in_queue, out_queue, timeout):
8484
out_queue.join()
8585

8686

87-
class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
87+
class StableDiffusionImg2ImgPipelineFastTests(
88+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
89+
):
8890
pipeline_class = StableDiffusionImg2ImgPipeline
8991
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
9092
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
from ...models.test_models_unet_2d_condition import create_lora_layers
4444
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
45-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
45+
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
4646

4747

4848
enable_full_determinism()
@@ -82,7 +82,9 @@ def _test_inpaint_compile(in_queue, out_queue, timeout):
8282
out_queue.join()
8383

8484

85-
class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
85+
class StableDiffusionInpaintPipelineFastTests(
86+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
87+
):
8688
pipeline_class = StableDiffusionInpaintPipeline
8789
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
8890
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS

tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@
4040
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
4141
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
4242
)
43-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
43+
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
4444

4545

4646
enable_full_determinism()
4747

4848

4949
class StableDiffusionInstructPix2PixPipelineFastTests(
50-
PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
50+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
5151
):
5252
pipeline_class = StableDiffusionInstructPix2PixPipeline
5353
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"}

tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,16 @@
3232
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps
3333

3434
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
35-
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
35+
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
3636

3737

3838
enable_full_determinism()
3939

4040

4141
@skip_mps
42-
class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
42+
class StableDiffusionModelEditingPipelineFastTests(
43+
PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
44+
):
4345
pipeline_class = StableDiffusionModelEditingPipeline
4446
params = TEXT_TO_IMAGE_PARAMS
4547
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

0 commit comments

Comments
 (0)