Commit 0e14cac

hlkyDN6 authored and committed
Fix batch > 1 in HunyuanVideo (#10548)
1 parent: 13ea83f

File tree: 1 file changed (+2, -1 lines)


Diff for: src/diffusers/models/transformers/transformer_hunyuan_video.py

```diff
@@ -721,7 +721,8 @@ def forward(
         for i in range(batch_size):
             attention_mask[i, : effective_sequence_length[i]] = True
-        attention_mask = attention_mask.unsqueeze(1)  # [B, 1, N], for broadcasting across attention heads
+        # [B, 1, 1, N], for broadcasting across attention heads
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)

         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
```
