@@ -155,14 +155,14 @@ def __call__(
155
155
height : int = 512 ,
156
156
width : int = 512 ,
157
157
prior_num_inference_steps : int = 60 ,
158
- prior_timesteps : Optional [List [float ]] = None ,
159
158
prior_guidance_scale : float = 4.0 ,
160
159
num_inference_steps : int = 12 ,
161
- decoder_timesteps : Optional [List [float ]] = None ,
162
160
decoder_guidance_scale : float = 0.0 ,
163
161
negative_prompt : Optional [Union [str , List [str ]]] = None ,
164
162
prompt_embeds : Optional [torch .FloatTensor ] = None ,
163
+ prompt_embeds_pooled : Optional [torch .FloatTensor ] = None ,
165
164
negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
165
+ negative_prompt_embeds_pooled : Optional [torch .FloatTensor ] = None ,
166
166
num_images_per_prompt : int = 1 ,
167
167
generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
168
168
latents : Optional [torch .FloatTensor ] = None ,
@@ -187,10 +187,17 @@ def __call__(
187
187
prompt_embeds (`torch.FloatTensor`, *optional*):
188
188
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
189
189
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
190
+ prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
191
+ Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
192
+ weighting. If not provided, text embeddings will be generated from `prompt` input argument.
190
193
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
191
194
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
192
195
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
193
196
input argument.
197
+ negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
198
+ Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
199
+ prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
200
+ input argument.
194
201
num_images_per_prompt (`int`, *optional*, defaults to 1):
195
202
The number of images to generate per prompt.
196
203
height (`int`, *optional*, defaults to 512):
@@ -253,7 +260,6 @@ def __call__(
253
260
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
254
261
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
255
262
"""
256
-
257
263
prior_outputs = self .prior_pipe (
258
264
prompt = prompt if prompt_embeds is None else None ,
259
265
images = images ,
@@ -263,7 +269,9 @@ def __call__(
263
269
guidance_scale = prior_guidance_scale ,
264
270
negative_prompt = negative_prompt if negative_prompt_embeds is None else None ,
265
271
prompt_embeds = prompt_embeds ,
272
+ prompt_embeds_pooled = prompt_embeds_pooled ,
266
273
negative_prompt_embeds = negative_prompt_embeds ,
274
+ negative_prompt_embeds_pooled = negative_prompt_embeds_pooled ,
267
275
num_images_per_prompt = num_images_per_prompt ,
268
276
generator = generator ,
269
277
latents = latents ,
@@ -274,7 +282,9 @@ def __call__(
274
282
)
275
283
image_embeddings = prior_outputs .image_embeddings
276
284
prompt_embeds = prior_outputs .get ("prompt_embeds" , None )
285
+ prompt_embeds_pooled = prior_outputs .get ("prompt_embeds_pooled" , None )
277
286
negative_prompt_embeds = prior_outputs .get ("negative_prompt_embeds" , None )
287
+ negative_prompt_embeds_pooled = prior_outputs .get ("negative_prompt_embeds_pooled" , None )
278
288
279
289
outputs = self .decoder_pipe (
280
290
image_embeddings = image_embeddings ,
@@ -283,7 +293,9 @@ def __call__(
283
293
guidance_scale = decoder_guidance_scale ,
284
294
negative_prompt = negative_prompt if negative_prompt_embeds is None else None ,
285
295
prompt_embeds = prompt_embeds ,
296
+ prompt_embeds_pooled = prompt_embeds_pooled ,
286
297
negative_prompt_embeds = negative_prompt_embeds ,
298
+ negative_prompt_embeds_pooled = negative_prompt_embeds_pooled ,
287
299
generator = generator ,
288
300
output_type = output_type ,
289
301
return_dict = return_dict ,
0 commit comments