Skip to content

Commit 69f4919

Browse files
authored
Fix passing pooled prompt embeds to Cascade Decoder and Combined Pipeline (#7287)
* update * update * update * update
1 parent ed224f9 commit 69f4919

File tree

6 files changed

+170
-12
lines changed

6 files changed

+170
-12
lines changed

src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,9 @@ def __call__(
289289
guidance_scale: float = 0.0,
290290
negative_prompt: Optional[Union[str, List[str]]] = None,
291291
prompt_embeds: Optional[torch.FloatTensor] = None,
292+
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
292293
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
294+
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
293295
num_images_per_prompt: int = 1,
294296
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
295297
latents: Optional[torch.FloatTensor] = None,
@@ -321,10 +323,17 @@ def __call__(
321323
prompt_embeds (`torch.FloatTensor`, *optional*):
322324
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
323325
provided, text embeddings will be generated from `prompt` input argument.
326+
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
327+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
328+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
324329
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
325330
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
326331
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
327332
argument.
333+
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
334+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
335+
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
336+
argument.
328337
num_images_per_prompt (`int`, *optional*, defaults to 1):
329338
The number of images to generate per prompt.
330339
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -378,18 +387,24 @@ def __call__(
378387

379388
# 2. Encode caption
380389
if prompt_embeds is None and negative_prompt_embeds is None:
381-
prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt(
390+
_, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt(
382391
prompt=prompt,
383392
device=device,
384393
batch_size=batch_size,
385394
num_images_per_prompt=num_images_per_prompt,
386395
do_classifier_free_guidance=self.do_classifier_free_guidance,
387396
negative_prompt=negative_prompt,
388397
prompt_embeds=prompt_embeds,
398+
prompt_embeds_pooled=prompt_embeds_pooled,
389399
negative_prompt_embeds=negative_prompt_embeds,
400+
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
390401
)
402+
403+
# The pooled embeds from the prior are pooled again before being passed to the decoder
391404
prompt_embeds_pooled = (
392-
torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds
405+
torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
406+
if self.do_classifier_free_guidance
407+
else prompt_embeds_pooled
393408
)
394409
effnet = (
395410
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])

src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,14 @@ def __call__(
155155
height: int = 512,
156156
width: int = 512,
157157
prior_num_inference_steps: int = 60,
158-
prior_timesteps: Optional[List[float]] = None,
159158
prior_guidance_scale: float = 4.0,
160159
num_inference_steps: int = 12,
161-
decoder_timesteps: Optional[List[float]] = None,
162160
decoder_guidance_scale: float = 0.0,
163161
negative_prompt: Optional[Union[str, List[str]]] = None,
164162
prompt_embeds: Optional[torch.FloatTensor] = None,
163+
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
165164
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
165+
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
166166
num_images_per_prompt: int = 1,
167167
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
168168
latents: Optional[torch.FloatTensor] = None,
@@ -187,10 +187,17 @@ def __call__(
187187
prompt_embeds (`torch.FloatTensor`, *optional*):
188188
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
189189
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
190+
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
191+
Pre-generated pooled text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
192+
weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
190193
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
191194
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
192195
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
193196
input argument.
197+
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
198+
Pre-generated negative pooled text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
199+
prompt weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt`
200+
input argument.
194201
num_images_per_prompt (`int`, *optional*, defaults to 1):
195202
The number of images to generate per prompt.
196203
height (`int`, *optional*, defaults to 512):
@@ -253,7 +260,6 @@ def __call__(
253260
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
254261
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
255262
"""
256-
257263
prior_outputs = self.prior_pipe(
258264
prompt=prompt if prompt_embeds is None else None,
259265
images=images,
@@ -263,7 +269,9 @@ def __call__(
263269
guidance_scale=prior_guidance_scale,
264270
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
265271
prompt_embeds=prompt_embeds,
272+
prompt_embeds_pooled=prompt_embeds_pooled,
266273
negative_prompt_embeds=negative_prompt_embeds,
274+
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
267275
num_images_per_prompt=num_images_per_prompt,
268276
generator=generator,
269277
latents=latents,
@@ -274,7 +282,9 @@ def __call__(
274282
)
275283
image_embeddings = prior_outputs.image_embeddings
276284
prompt_embeds = prior_outputs.get("prompt_embeds", None)
285+
prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None)
277286
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
287+
negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None)
278288

279289
outputs = self.decoder_pipe(
280290
image_embeddings=image_embeddings,
@@ -283,7 +293,9 @@ def __call__(
283293
guidance_scale=decoder_guidance_scale,
284294
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
285295
prompt_embeds=prompt_embeds,
296+
prompt_embeds_pooled=prompt_embeds_pooled,
286297
negative_prompt_embeds=negative_prompt_embeds,
298+
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
287299
generator=generator,
288300
output_type=output_type,
289301
return_dict=return_dict,

src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ class StableCascadePriorPipelineOutput(BaseOutput):
6464

6565
image_embeddings: Union[torch.FloatTensor, np.ndarray]
6666
prompt_embeds: Union[torch.FloatTensor, np.ndarray]
67+
prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
6768
negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray]
69+
negative_prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
6870

6971

7072
class StableCascadePriorPipeline(DiffusionPipeline):
@@ -305,6 +307,16 @@ def check_inputs(
305307
f" {negative_prompt_embeds.shape}."
306308
)
307309

310+
if prompt_embeds is not None and prompt_embeds_pooled is None:
311+
raise ValueError(
312+
"If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
313+
)
314+
315+
if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None:
316+
raise ValueError(
317+
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `negative_prompt_embeds_pooled` from the same text encoder that was used to generate `negative_prompt_embeds`"
318+
)
319+
308320
if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
309321
if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
310322
raise ValueError(
@@ -339,7 +351,7 @@ def do_classifier_free_guidance(self):
339351
def num_timesteps(self):
340352
return self._num_timesteps
341353

342-
def get_t_condioning(self, t, alphas_cumprod):
354+
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
343355
s = torch.tensor([0.003])
344356
clamp_range = [0, 1]
345357
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
@@ -558,7 +570,7 @@ def __call__(
558570
for i, t in enumerate(self.progress_bar(timesteps)):
559571
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
560572
if len(alphas_cumprod) > 0:
561-
timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod)
573+
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
562574
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
563575
else:
564576
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
@@ -609,6 +621,18 @@ def __call__(
609621
) # float() as bfloat16-> numpy doesnt work
610622

611623
if not return_dict:
612-
return (latents, prompt_embeds, negative_prompt_embeds)
624+
return (
625+
latents,
626+
prompt_embeds,
627+
prompt_embeds_pooled,
628+
negative_prompt_embeds,
629+
negative_prompt_embeds_pooled,
630+
)
613631

614-
return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds)
632+
return StableCascadePriorPipelineOutput(
633+
image_embeddings=latents,
634+
prompt_embeds=prompt_embeds,
635+
prompt_embeds_pooled=prompt_embeds_pooled,
636+
negative_prompt_embeds=negative_prompt_embeds,
637+
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
638+
)

tests/pipelines/stable_cascade/test_stable_cascade_combined.py

+36-3
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,39 @@ def test_float16_inference(self):
241241
def test_callback_inputs(self):
242242
super().test_callback_inputs()
243243

244-
# def test_callback_cfg(self):
245-
# pass
246-
# pass
244+
def test_stable_cascade_combined_prompt_embeds(self):
245+
device = "cpu"
246+
components = self.get_dummy_components()
247+
248+
pipe = StableCascadeCombinedPipeline(**components)
249+
pipe.set_progress_bar_config(disable=None)
250+
251+
prompt = "A photograph of a shiba inu, wearing a hat"
252+
(
253+
prompt_embeds,
254+
prompt_embeds_pooled,
255+
negative_prompt_embeds,
256+
negative_prompt_embeds_pooled,
257+
) = pipe.prior_pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
258+
generator = torch.Generator(device=device)
259+
260+
output_prompt = pipe(
261+
prompt=prompt,
262+
num_inference_steps=1,
263+
prior_num_inference_steps=1,
264+
output_type="np",
265+
generator=generator.manual_seed(0),
266+
)
267+
output_prompt_embeds = pipe(
268+
prompt=None,
269+
prompt_embeds=prompt_embeds,
270+
prompt_embeds_pooled=prompt_embeds_pooled,
271+
negative_prompt_embeds=negative_prompt_embeds,
272+
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
273+
num_inference_steps=1,
274+
prior_num_inference_steps=1,
275+
output_type="np",
276+
generator=generator.manual_seed(0),
277+
)
278+
279+
assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5

tests/pipelines/stable_cascade/test_stable_cascade_decoder.py

+39
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,45 @@ def test_attention_slicing_forward_pass(self):
207207
def test_float16_inference(self):
208208
super().test_float16_inference()
209209

210+
def test_stable_cascade_decoder_prompt_embeds(self):
211+
device = "cpu"
212+
components = self.get_dummy_components()
213+
214+
pipe = StableCascadeDecoderPipeline(**components)
215+
pipe.set_progress_bar_config(disable=None)
216+
217+
inputs = self.get_dummy_inputs(device)
218+
image_embeddings = inputs["image_embeddings"]
219+
prompt = "A photograph of a shiba inu, wearing a hat"
220+
(
221+
prompt_embeds,
222+
prompt_embeds_pooled,
223+
negative_prompt_embeds,
224+
negative_prompt_embeds_pooled,
225+
) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
226+
generator = torch.Generator(device=device)
227+
228+
decoder_output_prompt = pipe(
229+
image_embeddings=image_embeddings,
230+
prompt=prompt,
231+
num_inference_steps=1,
232+
output_type="np",
233+
generator=generator.manual_seed(0),
234+
)
235+
decoder_output_prompt_embeds = pipe(
236+
image_embeddings=image_embeddings,
237+
prompt=None,
238+
prompt_embeds=prompt_embeds,
239+
prompt_embeds_pooled=prompt_embeds_pooled,
240+
negative_prompt_embeds=negative_prompt_embeds,
241+
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
242+
num_inference_steps=1,
243+
output_type="np",
244+
generator=generator.manual_seed(0),
245+
)
246+
247+
assert np.abs(decoder_output_prompt.images - decoder_output_prompt_embeds.images).max() < 1e-5
248+
210249

211250
@slow
212251
@require_torch_gpu

tests/pipelines/stable_cascade/test_stable_cascade_prior.py

+35
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,41 @@ def test_inference_with_prior_lora(self):
273273

274274
self.assertTrue(image_embed.shape == lora_image_embed.shape)
275275

276+
def test_stable_cascade_decoder_prompt_embeds(self):
277+
device = "cpu"
278+
components = self.get_dummy_components()
279+
280+
pipe = self.pipeline_class(**components)
281+
pipe.set_progress_bar_config(disable=None)
282+
283+
prompt = "A photograph of a shiba inu, wearing a hat"
284+
(
285+
prompt_embeds,
286+
prompt_embeds_pooled,
287+
negative_prompt_embeds,
288+
negative_prompt_embeds_pooled,
289+
) = pipe.encode_prompt(device, 1, 1, False, prompt=prompt)
290+
generator = torch.Generator(device=device)
291+
292+
output_prompt = pipe(
293+
prompt=prompt,
294+
num_inference_steps=1,
295+
output_type="np",
296+
generator=generator.manual_seed(0),
297+
)
298+
output_prompt_embeds = pipe(
299+
prompt=None,
300+
prompt_embeds=prompt_embeds,
301+
prompt_embeds_pooled=prompt_embeds_pooled,
302+
negative_prompt_embeds=negative_prompt_embeds,
303+
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
304+
num_inference_steps=1,
305+
output_type="np",
306+
generator=generator.manual_seed(0),
307+
)
308+
309+
assert np.abs(output_prompt.image_embeddings - output_prompt_embeds.image_embeddings).max() < 1e-5
310+
276311

277312
@slow
278313
@require_torch_gpu

0 commit comments

Comments
 (0)