@@ -371,7 +371,7 @@ def encode_prompt(
                 unscale_lora_layers(self.text_encoder_2, lora_scale)

         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

         return prompt_embeds, pooled_prompt_embeds, text_ids

@@ -427,7 +427,7 @@ def check_inputs(

     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
+        latent_image_ids = torch.zeros(height, width, 3)
         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

@@ -437,7 +437,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
             latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )

-        return latent_image_ids
+        return latent_image_ids.to(device=device, dtype=dtype)

     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
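As a rough standalone sketch of the pattern these hunks converge on (not the pipeline's exact code): the id tensors are built in torch's default dtype first, and the device/dtype cast happens only once at the end via .to(device=..., dtype=...). The helper name and signature below are illustrative; the real _prepare_latent_image_ids also takes a batch_size argument that is omitted here.

import torch

def prepare_latent_image_ids(height, width, device, dtype):
    # Build the positional id grid in torch's default dtype; no device/dtype
    # is passed to torch.zeros or torch.arange at construction time.
    latent_image_ids = torch.zeros(height, width, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

    # Flatten the (height, width) grid into one id row per latent position.
    latent_image_ids = latent_image_ids.reshape(height * width, 3)

    # Cast and move in a single step at the end, mirroring the diff.
    return latent_image_ids.to(device=device, dtype=dtype)

# Example usage: ids for a 32x32 latent grid in bfloat16 on the CPU.
ids = prepare_latent_image_ids(32, 32, device="cpu", dtype=torch.bfloat16)
print(ids.shape, ids.dtype)  # torch.Size([1024, 3]) torch.bfloat16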