@@ -604,8 +604,7 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
-        self.llama_layers = llama_layers
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
         self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
@@ -621,13 +620,13 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
 
@@ -636,42 +635,26 @@ def __init__(
                 HiDreamBlock(
                     HiDreamImageSingleTransformerBlock(
                         dim=self.inner_dim,
-                        num_attention_heads=self.config.num_attention_heads,
-                        attention_head_dim=self.config.attention_head_dim,
+                        num_attention_heads=num_attention_heads,
+                        attention_head_dim=attention_head_dim,
                         num_routed_experts=num_routed_experts,
                         num_activated_experts=num_activated_experts,
                     )
                 )
-                for _ in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
 
         self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
 
-        caption_channels = [
-            caption_channels[1],
-        ] * (num_layers + num_single_layers) + [
-            caption_channels[0],
-        ]
+        caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
         caption_projection = []
         for caption_channel in caption_channels:
             caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
         self.caption_projection = nn.ModuleList(caption_projection)
         self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
 
-    def expand_timesteps(self, timesteps, batch_size, device):
-        if not torch.is_tensor(timesteps):
-            is_mps = device.type == "mps"
-            if isinstance(timesteps, float):
-                dtype = torch.float32 if is_mps else torch.float64
-            else:
-                dtype = torch.int32 if is_mps else torch.int64
-            timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
-        elif len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(device)
-        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timesteps.expand(batch_size)
-        return timesteps
+        self.gradient_checkpointing = False
 
     def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
         if is_training:
@@ -773,7 +756,6 @@ def forward(
             hidden_states = out
 
         # 0. time
-        timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
@@ -793,7 +775,7 @@ def forward(
 
         T5_encoder_hidden_states = encoder_hidden_states[0]
         encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
 
         if self.caption_projection is not None:
             new_encoder_hidden_states = []
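
Note on why dropping `self.llama_layers = llama_layers` still leaves `self.config.llama_layers` readable in forward(): in diffusers, a model whose `__init__` is wrapped with `@register_to_config` has every constructor argument recorded on `self.config`, so the plain attribute is redundant (both sides of this diff already read from `self.config`, which implies that decorator is in use here). Below is a minimal, self-contained sketch of that pattern; the `TinyConfiguredModel` class, its arguments, and the toy tensors are invented for illustration and are not part of the HiDream code.

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class TinyConfiguredModel(ModelMixin, ConfigMixin):
    # ModelMixin already supplies config_name; repeated here so the sketch stands alone.
    config_name = "config.json"

    @register_to_config
    def __init__(self, llama_layers=(0, 1, 2), num_attention_heads=8, attention_head_dim=64):
        super().__init__()
        # No `self.llama_layers = llama_layers` needed: the decorator has already
        # recorded every __init__ argument on the frozen `self.config`.
        self.inner_dim = num_attention_heads * attention_head_dim
        self.proj = nn.Linear(self.inner_dim, self.inner_dim)

    def forward(self, per_layer_states):
        # Read the value back from the config, as the forward() hunk above does.
        picked = [per_layer_states[k] for k in self.config.llama_layers]
        return [self.proj(h) for h in picked]


model = TinyConfiguredModel()
layer_states = [torch.randn(1, 16, model.inner_dim) for _ in range(4)]
outputs = model(layer_states)  # selects layers 0, 1, 2 via model.config.llama_layers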