Skip to content

Commit 291db3e

Browse files
committed
Revert "[Flux] reduce explicit device transfers and typecasting in flux." (#9896)
Revert "[Flux] reduce explicit device transfers and typecasting in flux. (#9817)" This reverts commit 5588725.
1 parent dbea93c commit 291db3e

6 files changed

+17
-17
lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def encode_prompt(
371371
unscale_lora_layers(self.text_encoder_2, lora_scale)
372372

373373
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
374-
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
374+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
375375

376376
return prompt_embeds, pooled_prompt_embeds, text_ids
377377

@@ -427,7 +427,7 @@ def check_inputs(
427427

428428
@staticmethod
429429
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
430-
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
430+
latent_image_ids = torch.zeros(height, width, 3)
431431
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
432432
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
433433

@@ -437,7 +437,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
437437
latent_image_id_height * latent_image_id_width, latent_image_id_channels
438438
)
439439

440-
return latent_image_ids
440+
return latent_image_ids.to(device=device, dtype=dtype)
441441

442442
@staticmethod
443443
def _pack_latents(latents, batch_size, num_channels_latents, height, width):

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def check_inputs(
452452
@staticmethod
453453
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
454454
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
455-
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
455+
latent_image_ids = torch.zeros(height, width, 3)
456456
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
457457
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
458458

@@ -462,7 +462,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
462462
latent_image_id_height * latent_image_id_width, latent_image_id_channels
463463
)
464464

465-
return latent_image_ids
465+
return latent_image_ids.to(device=device, dtype=dtype)
466466

467467
@staticmethod
468468
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def encode_prompt(
407407
unscale_lora_layers(self.text_encoder_2, lora_scale)
408408

409409
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
410-
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
410+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
411411

412412
return prompt_embeds, pooled_prompt_embeds, text_ids
413413

@@ -495,7 +495,7 @@ def check_inputs(
495495
@staticmethod
496496
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
497497
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
498-
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
498+
latent_image_ids = torch.zeros(height, width, 3)
499499
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
500500
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
501501

@@ -505,7 +505,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
505505
latent_image_id_height * latent_image_id_width, latent_image_id_channels
506506
)
507507

508-
return latent_image_ids
508+
return latent_image_ids.to(device=device, dtype=dtype)
509509

510510
@staticmethod
511511
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def encode_prompt(
417417
unscale_lora_layers(self.text_encoder_2, lora_scale)
418418

419419
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
420-
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
420+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
421421

422422
return prompt_embeds, pooled_prompt_embeds, text_ids
423423

@@ -522,7 +522,7 @@ def check_inputs(
522522
@staticmethod
523523
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
524524
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
525-
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
525+
latent_image_ids = torch.zeros(height, width, 3)
526526
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
527527
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
528528

@@ -532,7 +532,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
532532
latent_image_id_height * latent_image_id_width, latent_image_id_channels
533533
)
534534

535-
return latent_image_ids
535+
return latent_image_ids.to(device=device, dtype=dtype)
536536

537537
@staticmethod
538538
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def encode_prompt(
391391
unscale_lora_layers(self.text_encoder_2, lora_scale)
392392

393393
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
394-
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
394+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
395395

396396
return prompt_embeds, pooled_prompt_embeds, text_ids
397397

@@ -479,7 +479,7 @@ def check_inputs(
479479
@staticmethod
480480
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
481481
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
482-
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
482+
latent_image_ids = torch.zeros(height, width, 3)
483483
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
484484
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
485485

@@ -489,7 +489,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
489489
latent_image_id_height * latent_image_id_width, latent_image_id_channels
490490
)
491491

492-
return latent_image_ids
492+
return latent_image_ids.to(device=device, dtype=dtype)
493493

494494
@staticmethod
495495
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents

src/diffusers/pipelines/flux/pipeline_flux_inpaint.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def encode_prompt(
395395
unscale_lora_layers(self.text_encoder_2, lora_scale)
396396

397397
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
398-
text_ids = torch.zeros(prompt_embeds.shape[1], 3, dtype=dtype, device=device)
398+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
399399

400400
return prompt_embeds, pooled_prompt_embeds, text_ids
401401

@@ -500,7 +500,7 @@ def check_inputs(
500500
@staticmethod
501501
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
502502
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
503-
latent_image_ids = torch.zeros(height, width, 3, device=device, dtype=dtype)
503+
latent_image_ids = torch.zeros(height, width, 3)
504504
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
505505
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
506506

@@ -510,7 +510,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
510510
latent_image_id_height * latent_image_id_width, latent_image_id_channels
511511
)
512512

513-
return latent_image_ids
513+
return latent_image_ids.to(device=device, dtype=dtype)
514514

515515
@staticmethod
516516
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents

0 commit comments

Comments (0)