@@ -205,6 +205,110 @@ def disable_attention_slicing(self):
        # set slice_size = `None` to disable `set_attention_slice`
        self.enable_attention_slicing(None)

+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            max_length = text_input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
    @torch.no_grad()
    def __call__(
        self,
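Note (not part of the diff): below is a minimal, self-contained sketch of the embedding layout that `_encode_prompt` returns when classifier-free guidance is enabled. The tensor sizes are illustrative stand-ins for the CLIP text-encoder output; only the repeat/view/concat pattern mirrors the method added above.

# Illustrative only: dummy tensors stand in for text_encoder(...) outputs.
import torch

batch_size, num_images_per_prompt, seq_len, dim = 2, 3, 77, 768

cond = torch.randn(batch_size, seq_len, dim)    # embeddings of `prompt`
uncond = torch.randn(batch_size, seq_len, dim)  # embeddings of "" or `negative_prompt`

# duplicate each prompt's embedding num_images_per_prompt times (mps-friendly repeat + view)
cond = cond.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)
uncond = uncond.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)

# one batch, unconditional embeddings first, so a single UNet forward pass covers both halves
text_embeddings = torch.cat([uncond, cond])
assert text_embeddings.shape == (2 * batch_size * num_images_per_prompt, seq_len, dim)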
@@ -309,89 +413,17 @@ def __call__(
        if isinstance(init_image, PIL.Image.Image):
            init_image = preprocess(init_image)

-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        source_text_inputs = self.tokenizer(
-            source_prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        source_text_input_ids = source_text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        if source_text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(source_text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            source_text_input_ids = source_text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-        source_text_embeddings = self.text_encoder(source_text_input_ids.to(self.device))[0]
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-        source_text_embeddings = source_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        device = self._execution_device

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0

-        # get unconditional embeddings for classifier free guidance
-        uncond_tokens = [""]
-
-        max_length = text_input_ids.shape[-1]
-        uncond_input = self.tokenizer(
-            uncond_tokens,
-            padding="max_length",
-            max_length=max_length,
-            truncation=True,
-            return_tensors="pt",
+        text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
+        source_text_embeddings = self._encode_prompt(
+            source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
        )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
-
-        # duplicate unconditional embeddings for each generation per prompt
-        uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
-
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        source_uncond_tokens = [""]
-
-        max_length = source_text_input_ids.shape[-1]
-        source_uncond_input = self.tokenizer(
-            source_uncond_tokens,
-            padding="max_length",
-            max_length=max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        source_uncond_embeddings = self.text_encoder(source_uncond_input.input_ids.to(self.device))[0]
-
-        # duplicate unconditional embeddings for each generation per prompt
-        source_uncond_embeddings = source_uncond_embeddings.repeat_interleave(
-            batch_size * num_images_per_prompt, dim=0
-        )
-
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        source_text_embeddings = torch.cat([source_uncond_embeddings, source_text_embeddings])

        # encode the init image into latents and scale the latents
        latents_dtype = text_embeddings.dtype
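Note (not part of the diff): a hedged usage sketch of the pipeline whose `__call__` is refactored above. The keyword names (`prompt`, `source_prompt`, `init_image`, `num_images_per_prompt`, `guidance_scale`, `eta`) follow that signature; the checkpoint id, scheduler setup, and file paths are illustrative assumptions, not prescribed by this commit.

# Illustrative usage sketch, not part of this commit.
from PIL import Image
from diffusers import CycleDiffusionPipeline, DDIMScheduler

model_id = "CompVis/stable-diffusion-v1-4"  # example checkpoint
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler).to("cuda")

# any RGB image; CycleDiffusion edits it away from source_prompt toward prompt
init_image = Image.open("black_car.png").convert("RGB").resize((512, 512))

image = pipe(
    prompt="A blue colored car",
    source_prompt="A black colored car",
    init_image=init_image,        # keyword name as used in this version of the pipeline
    num_images_per_prompt=1,
    guidance_scale=7.5,           # > 1.0 enables classifier-free guidance in _encode_prompt
    eta=0.1,                      # stochastic DDIM step, which CycleDiffusion relies on
).images[0]
image.save("blue_car.png")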