
Commit 97e0ef4

a-r-r-o-w, hlky, and github-actions[bot] authored
Hidream refactoring follow ups (#11299)
* HiDream Image
* update
* -einops
* py3.8
* fix -einops
* mixins, offload_seq, option_components
* docs
* Apply style fixes
* trigger tests
* Apply suggestions from code review

  Co-authored-by: Aryan <[email protected]>

* joint_attention_kwargs -> attention_kwargs, fixes
* fast tests
* -_init_weights
* style tests
* move reshape logic
* update slice 😴
* supports_dduf
* 🤷🏻‍♂️
* Update src/diffusers/models/transformers/transformer_hidream_image.py

  Co-authored-by: Aryan <[email protected]>

* address review comments
* update tests
* doc updates
* update
* Update src/diffusers/models/transformers/transformer_hidream_image.py
* Apply style fixes

---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent ed41db8 commit 97e0ef4
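
For context, a hypothetical usage sketch of the transformer touched by this commit. The class name HiDreamImageTransformer2DModel and the checkpoint id are assumptions for illustration and are not taken from this diff.

# Hypothetical usage sketch; class name and repo id are assumptions, not from the diff.
import torch
from diffusers import HiDreamImageTransformer2DModel

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # assumed checkpoint id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
print(transformer.config.num_attention_heads, transformer.config.attention_head_dim)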

File tree: 1 file changed (+10 -28 lines)


src/diffusers/models/transformers/transformer_hidream_image.py (+10 -28)
@@ -604,8 +604,7 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
-        self.llama_layers = llama_layers
+        self.inner_dim = num_attention_heads * attention_head_dim

         self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
         self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
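
The two removed lines rely on diffusers models registering their __init__ arguments on self.config: inner_dim can be computed from the local arguments directly, and llama_layers stays reachable as self.config.llama_layers (see the forward() hunk further down) without a separate attribute. A minimal sketch of that mechanism, using a toy class rather than the real HiDream model:

# Minimal sketch (toy class, not the real HiDreamImageTransformer2DModel) of how
# diffusers' ConfigMixin + register_to_config keep __init__ arguments available
# on self.config, which is why self.llama_layers no longer needs to be stored.
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, num_attention_heads: int = 8, attention_head_dim: int = 64, llama_layers: tuple = ()):
        super().__init__()
        # Plain local arguments are enough here; no need to read back through self.config.
        self.inner_dim = num_attention_heads * attention_head_dim
        self.proj = nn.Linear(self.inner_dim, self.inner_dim)


model = ToyTransformer(llama_layers=(0, 2, 4))
print(model.config.llama_layers)  # recorded automatically by the decorator
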
@@ -621,13 +620,13 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )

@@ -636,42 +635,26 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageSingleTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )

         self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)

-        caption_channels = [
-            caption_channels[1],
-        ] * (num_layers + num_single_layers) + [
-            caption_channels[0],
-        ]
+        caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
         caption_projection = []
         for caption_channel in caption_channels:
             caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
         self.caption_projection = nn.ModuleList(caption_projection)
         self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)

-    def expand_timesteps(self, timesteps, batch_size, device):
-        if not torch.is_tensor(timesteps):
-            is_mps = device.type == "mps"
-            if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
-            else:
-                dtype = torch.int32 if is_mps else torch.int64
-            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
-        elif len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(device)
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timesteps.expand(batch_size)
-        return timesteps
+        self.gradient_checkpointing = False

     def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
         if is_training:
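
Two things happen in this hunk: the caption_channels list construction is collapsed to one line, and the expand_timesteps helper is removed while a self.gradient_checkpointing = False flag is added at the end of __init__. A quick check with dummy values that the one-liner builds the same list as the old multi-line form:

# Dummy values, just to confirm the one-line rewrite is behaviour-preserving.
caption_channels = [4096, 2048]   # hypothetical channel sizes for illustration
num_layers, num_single_layers = 2, 3

old = [caption_channels[1],] * (num_layers + num_single_layers) + [caption_channels[0],]
new = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
assert old == new  # same list: one entry per block, plus one final entry for caption_channels[0]
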
@@ -773,7 +756,6 @@ def forward(
            hidden_states = out

         # 0. time
-        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
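
This hunk drops the forward() call to the helper deleted above. For reference, the removed expand_timesteps logic is reproduced standalone below; after this change the module presumably expects timesteps to already arrive as a 1-d tensor of shape (batch_size,), which is an inference from the deletion rather than something the diff states.

import torch

def expand_timesteps(timesteps, batch_size: int, device: torch.device) -> torch.Tensor:
    # Behaviour of the helper removed in this commit, reproduced standalone.
    if not torch.is_tensor(timesteps):
        is_mps = device.type == "mps"
        if isinstance(timesteps, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(device)
    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    return timesteps.expand(batch_size)
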
@@ -793,7 +775,7 @@ def forward(

         T5_encoder_hidden_states = encoder_hidden_states[0]
         encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]

         if self.caption_projection is not None:
             new_encoder_hidden_states = []
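
With the stored self.llama_layers attribute gone (first hunk), forward() now reads the same value from self.config. A tiny illustration with made-up shapes of what that selection does:

# Dummy illustration of the per-layer selection; shapes and indices are made up.
import torch

llama_layers = (0, 2, 4)                                           # hypothetical config value
encoder_hidden_states = [torch.randn(1, 7, 16) for _ in range(6)]  # stand-in per-layer Llama outputs
selected = [encoder_hidden_states[k] for k in llama_layers]
print(len(selected))  # 3 tensors, one for each configured layer index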
