Skip to content

Fix passing pooled prompt embeds to Cascade Decoder and Combined Pipeline #7287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,9 @@ def __call__(
guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
Expand Down Expand Up @@ -321,10 +323,17 @@ def __call__(
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
Expand Down Expand Up @@ -378,18 +387,24 @@ def __call__(

# 2. Encode caption
if prompt_embeds is None and negative_prompt_embeds is None:
prompt_embeds, _, negative_prompt_embeds, _ = self.encode_prompt(
_, prompt_embeds_pooled, _, negative_prompt_embeds_pooled = self.encode_prompt(
prompt=prompt,
device=device,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
)

# The pooled embeds from the prior are pooled again before being passed to the decoder
prompt_embeds_pooled = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if self.do_classifier_free_guidance else prompt_embeds
torch.cat([prompt_embeds_pooled, negative_prompt_embeds_pooled])
if self.do_classifier_free_guidance
else prompt_embeds_pooled
)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,14 @@ def __call__(
height: int = 512,
width: int = 512,
prior_num_inference_steps: int = 60,
prior_timesteps: Optional[List[float]] = None,
prior_guidance_scale: float = 4.0,
num_inference_steps: int = 12,
decoder_timesteps: Optional[List[float]] = None,
decoder_guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_pooled: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
Expand All @@ -187,10 +187,17 @@ def __call__(
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
input argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt`
input argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to 512):
Expand Down Expand Up @@ -253,7 +260,6 @@ def __call__(
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""

prior_outputs = self.prior_pipe(
prompt=prompt if prompt_embeds is None else None,
images=images,
Expand All @@ -263,7 +269,9 @@ def __call__(
guidance_scale=prior_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
Expand All @@ -274,7 +282,9 @@ def __call__(
)
image_embeddings = prior_outputs.image_embeddings
prompt_embeds = prior_outputs.get("prompt_embeds", None)
prompt_embeds_pooled = prior_outputs.get("prompt_embeds_pooled", None)
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
negative_prompt_embeds_pooled = prior_outputs.get("negative_prompt_embeds_pooled", None)

outputs = self.decoder_pipe(
image_embeddings=image_embeddings,
Expand All @@ -283,7 +293,9 @@ def __call__(
guidance_scale=decoder_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
generator=generator,
output_type=output_type,
return_dict=return_dict,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ class StableCascadePriorPipelineOutput(BaseOutput):

image_embeddings: Union[torch.FloatTensor, np.ndarray]
prompt_embeds: Union[torch.FloatTensor, np.ndarray]
prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]
negative_prompt_embeds: Union[torch.FloatTensor, np.ndarray]
negative_prompt_embeds_pooled: Union[torch.FloatTensor, np.ndarray]


class StableCascadePriorPipeline(DiffusionPipeline):
Expand Down Expand Up @@ -305,6 +307,16 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

if prompt_embeds is not None and prompt_embeds_pooled is None:
raise ValueError(
"If `prompt_embeds` are provided, `prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
)

if negative_prompt_embeds is not None and negative_prompt_embeds_pooled is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_prompt_embeds_pooled` must also be provided. Make sure to generate `prompt_embeds_pooled` from the same text encoder that was used to generate `prompt_embeds`"
)

if prompt_embeds_pooled is not None and negative_prompt_embeds_pooled is not None:
if prompt_embeds_pooled.shape != negative_prompt_embeds_pooled.shape:
raise ValueError(
Expand Down Expand Up @@ -339,7 +351,7 @@ def do_classifier_free_guidance(self):
def num_timesteps(self):
return self._num_timesteps

def get_t_condioning(self, t, alphas_cumprod):
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
s = torch.tensor([0.003])
clamp_range = [0, 1]
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
Expand Down Expand Up @@ -558,7 +570,7 @@ def __call__(
for i, t in enumerate(self.progress_bar(timesteps)):
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
if len(alphas_cumprod) > 0:
timestep_ratio = self.get_t_condioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
else:
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
Expand Down Expand Up @@ -609,6 +621,18 @@ def __call__(
) # float() as bfloat16-> numpy doesnt work

if not return_dict:
return (latents, prompt_embeds, negative_prompt_embeds)
return (
latents,
prompt_embeds,
prompt_embeds_pooled,
negative_prompt_embeds,
negative_prompt_embeds_pooled,
)

return StableCascadePriorPipelineOutput(latents, prompt_embeds, negative_prompt_embeds)
return StableCascadePriorPipelineOutput(
image_embeddings=latents,
prompt_embeds=prompt_embeds,
prompt_embeds_pooled=prompt_embeds_pooled,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_pooled=negative_prompt_embeds_pooled,
)
39 changes: 36 additions & 3 deletions tests/pipelines/stable_cascade/test_stable_cascade_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,39 @@ def test_float16_inference(self):
def test_callback_inputs(self):
super().test_callback_inputs()

# def test_callback_cfg(self):
# pass
# pass
def test_stable_cascade_combined_prompt_embeds(self):
    """Running the combined pipeline with precomputed (pooled) prompt embeddings
    must reproduce the output of the plain text-prompt path exactly."""
    device = "cpu"
    pipe = StableCascadeCombinedPipeline(**self.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    prompt = "A photograph of a shiba inu, wearing a hat"
    # Encode once with the prior's text encoder (batch_size=1, num_images_per_prompt=1, CFG off).
    embeds, pooled_embeds, neg_embeds, neg_pooled_embeds = pipe.prior_pipe.encode_prompt(
        device, 1, 1, False, prompt=prompt
    )
    generator = torch.Generator(device=device)
    shared_kwargs = {
        "num_inference_steps": 1,
        "prior_num_inference_steps": 1,
        "output_type": "np",
    }

    from_prompt = pipe(prompt=prompt, generator=generator.manual_seed(0), **shared_kwargs)
    from_embeds = pipe(
        prompt=None,
        prompt_embeds=embeds,
        prompt_embeds_pooled=pooled_embeds,
        negative_prompt_embeds=neg_embeds,
        negative_prompt_embeds_pooled=neg_pooled_embeds,
        generator=generator.manual_seed(0),
        **shared_kwargs,
    )

    # Same seed + identical conditioning -> numerically identical images.
    assert np.abs(from_prompt.images - from_embeds.images).max() < 1e-5
39 changes: 39 additions & 0 deletions tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,45 @@ def test_attention_slicing_forward_pass(self):
def test_float16_inference(self):
super().test_float16_inference()

def test_stable_cascade_decoder_prompt_embeds(self):
    """Running the decoder with precomputed (pooled) prompt embeddings must
    reproduce the output of the plain text-prompt path exactly."""
    device = "cpu"
    pipe = StableCascadeDecoderPipeline(**self.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    image_embeddings = self.get_dummy_inputs(device)["image_embeddings"]
    prompt = "A photograph of a shiba inu, wearing a hat"
    # Encode once (batch_size=1, num_images_per_prompt=1, CFG off).
    embeds, pooled_embeds, neg_embeds, neg_pooled_embeds = pipe.encode_prompt(
        device, 1, 1, False, prompt=prompt
    )
    generator = torch.Generator(device=device)

    via_prompt = pipe(
        image_embeddings=image_embeddings,
        prompt=prompt,
        num_inference_steps=1,
        output_type="np",
        generator=generator.manual_seed(0),
    )
    via_embeds = pipe(
        image_embeddings=image_embeddings,
        prompt=None,
        prompt_embeds=embeds,
        prompt_embeds_pooled=pooled_embeds,
        negative_prompt_embeds=neg_embeds,
        negative_prompt_embeds_pooled=neg_pooled_embeds,
        num_inference_steps=1,
        output_type="np",
        generator=generator.manual_seed(0),
    )

    # Same seed + identical conditioning -> numerically identical images.
    assert np.abs(via_prompt.images - via_embeds.images).max() < 1e-5


@slow
@require_torch_gpu
Expand Down
35 changes: 35 additions & 0 deletions tests/pipelines/stable_cascade/test_stable_cascade_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,41 @@ def test_inference_with_prior_lora(self):

self.assertTrue(image_embed.shape == lora_image_embed.shape)

def test_stable_cascade_decoder_prompt_embeds(self):
    """Prior image embeddings from precomputed (pooled) prompt embeddings must
    match the plain text-prompt path exactly.

    NOTE(review): despite the name, this exercises the *prior* pipeline
    (``self.pipeline_class``); the method name looks copy-pasted from the
    decoder tests — consider renaming to ``test_stable_cascade_prior_prompt_embeds``.
    """
    device = "cpu"
    pipe = self.pipeline_class(**self.get_dummy_components())
    pipe.set_progress_bar_config(disable=None)

    prompt = "A photograph of a shiba inu, wearing a hat"
    # Encode once (batch_size=1, num_images_per_prompt=1, CFG off).
    embeds, pooled_embeds, neg_embeds, neg_pooled_embeds = pipe.encode_prompt(
        device, 1, 1, False, prompt=prompt
    )
    generator = torch.Generator(device=device)
    shared_kwargs = {"num_inference_steps": 1, "output_type": "np"}

    via_prompt = pipe(prompt=prompt, generator=generator.manual_seed(0), **shared_kwargs)
    via_embeds = pipe(
        prompt=None,
        prompt_embeds=embeds,
        prompt_embeds_pooled=pooled_embeds,
        negative_prompt_embeds=neg_embeds,
        negative_prompt_embeds_pooled=neg_pooled_embeds,
        generator=generator.manual_seed(0),
        **shared_kwargs,
    )

    # Same seed + identical conditioning -> identical image embeddings.
    assert np.abs(via_prompt.image_embeddings - via_embeds.image_embeddings).max() < 1e-5


@slow
@require_torch_gpu
Expand Down