
Commit 74cab5b

DN6, sayakpaul, and a-r-r-o-w committed
Fix Mochi Quality Issues (#10033)
* update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * Update src/diffusers/models/transformers/transformer_mochi.py Co-authored-by: Aryan <[email protected]> --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Aryan <[email protected]>
1 parent b2e245c commit 74cab5b

File tree

7 files changed: +337 -159 lines changed


src/diffusers/models/attention_processor.py (+172 -89)
@@ -906,6 +906,177 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.processor(self, hidden_states)
 
 
+class MochiAttention(nn.Module):
+    def __init__(
+        self,
+        query_dim: int,
+        added_kv_proj_dim: int,
+        processor: "MochiAttnProcessor2_0",
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        added_proj_bias: bool = True,
+        out_dim: Optional[int] = None,
+        out_context_dim: Optional[int] = None,
+        out_bias: bool = True,
+        context_pre_only: bool = False,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        from .normalization import MochiRMSNorm
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim else query_dim
+        self.context_pre_only = context_pre_only
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        if self.context_pre_only is not None:
+            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        if not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+
+class MochiAttnProcessor2_0:
+    """Attention processor used in Mochi."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "MochiAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+        if attn.norm_added_q is not None:
+            encoder_query = attn.norm_added_q(encoder_query)
+        if attn.norm_added_k is not None:
+            encoder_key = attn.norm_added_k(encoder_key)
+
+        if image_rotary_emb is not None:
+
+            def apply_rotary_emb(x, freqs_cos, freqs_sin):
+                x_even = x[..., 0::2].float()
+                x_odd = x[..., 1::2].float()
+
+                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
+                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+
+                return torch.stack([cos, sin], dim=-1).flatten(-2)
+
+            query = apply_rotary_emb(query, *image_rotary_emb)
+            key = apply_rotary_emb(key, *image_rotary_emb)
+
+        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )
+
+        sequence_length = query.size(2)
+        encoder_sequence_length = encoder_query.size(2)
+        total_length = sequence_length + encoder_sequence_length
+
+        batch_size, heads, _, dim = query.shape
+        attn_outputs = []
+        for idx in range(batch_size):
+            mask = attention_mask[idx][None, :]
+            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
+
+            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
+            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
+
+            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
+            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
+            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
+
+            attn_output = F.scaled_dot_product_attention(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
+            )
+            valid_sequence_length = attn_output.size(2)
+            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
+            attn_outputs.append(attn_output)
+
+        hidden_states = torch.cat(attn_outputs, dim=0)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if hasattr(attn, "to_add_out"):
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
@@ -3868,94 +4039,6 @@ def __call__(
         return hidden_states
 
 
-class MochiAttnProcessor2_0:
-    """Attention processor used in Mochi."""
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        query = query.unflatten(2, (attn.heads, -1))
-        key = key.unflatten(2, (attn.heads, -1))
-        value = value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        encoder_query = attn.add_q_proj(encoder_hidden_states)
-        encoder_key = attn.add_k_proj(encoder_hidden_states)
-        encoder_value = attn.add_v_proj(encoder_hidden_states)
-
-        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
-        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
-        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
-
-        if attn.norm_added_q is not None:
-            encoder_query = attn.norm_added_q(encoder_query)
-        if attn.norm_added_k is not None:
-            encoder_key = attn.norm_added_k(encoder_key)
-
-        if image_rotary_emb is not None:
-
-            def apply_rotary_emb(x, freqs_cos, freqs_sin):
-                x_even = x[..., 0::2].float()
-                x_odd = x[..., 1::2].float()
-
-                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
-                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
-
-                return torch.stack([cos, sin], dim=-1).flatten(-2)
-
-            query = apply_rotary_emb(query, *image_rotary_emb)
-            key = apply_rotary_emb(key, *image_rotary_emb)
-
-        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
-        encoder_query, encoder_key, encoder_value = (
-            encoder_query.transpose(1, 2),
-            encoder_key.transpose(1, 2),
-            encoder_value.transpose(1, 2),
-        )
-
-        sequence_length = query.size(2)
-        encoder_sequence_length = encoder_query.size(2)
-
-        query = torch.cat([query, encoder_query], dim=2)
-        key = torch.cat([key, encoder_key], dim=2)
-        value = torch.cat([value, encoder_value], dim=2)
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
-
-        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
-            (sequence_length, encoder_sequence_length), dim=1
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if getattr(attn, "to_add_out", None) is not None:
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        return hidden_states, encoder_hidden_states
-
-
 class FusedAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
@@ -5668,13 +5751,13 @@ def __call__(
     AttnProcessorNPU,
     AttnProcessor2_0,
     MochiVaeAttnProcessor2_0,
+    MochiAttnProcessor2_0,
     StableAudioAttnProcessor2_0,
     HunyuanAttnProcessor2_0,
     FusedHunyuanAttnProcessor2_0,
     PAGHunyuanAttnProcessor2_0,
     PAGCFGHunyuanAttnProcessor2_0,
     LuminaAttnProcessor2_0,
-    MochiAttnProcessor2_0,
     FusedAttnProcessor2_0,
     CustomDiffusionXFormersAttnProcessor,
     CustomDiffusionAttnProcessor2_0,
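
The removed MochiAttnProcessor2_0 above concatenated the video tokens with the full padded prompt sequence and attended over everything, so padding tokens in short prompts took part in attention; the new processor instead gathers only the valid prompt tokens per sample (via attention_mask), attends over the shortened sequence, and pads the result back to a common length before re-batching. Below is a minimal standalone sketch of that per-sample pattern; the shapes, the random inputs, and the mask values are illustrative assumptions and not part of the commit.

import torch
import torch.nn.functional as F

# Illustrative sizes (assumptions): 2 samples, 4 heads, 6 video tokens, 5 padded prompt tokens.
batch_size, heads, dim = 2, 4, 8
video_len, prompt_len = 6, 5

query = torch.randn(batch_size, heads, video_len, dim)
key = torch.randn(batch_size, heads, video_len, dim)
value = torch.randn(batch_size, heads, video_len, dim)
encoder_query = torch.randn(batch_size, heads, prompt_len, dim)
encoder_key = torch.randn(batch_size, heads, prompt_len, dim)
encoder_value = torch.randn(batch_size, heads, prompt_len, dim)

# True for real prompt tokens, False for padding (second prompt has only 3 valid tokens).
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]], dtype=torch.bool)

total_length = video_len + prompt_len
outputs = []
for idx in range(batch_size):
    valid = torch.nonzero(attention_mask[idx], as_tuple=False).flatten()

    # Joint attention over the video tokens plus only this sample's valid prompt tokens.
    q = torch.cat([query[idx : idx + 1], encoder_query[idx : idx + 1, :, valid]], dim=2)
    k = torch.cat([key[idx : idx + 1], encoder_key[idx : idx + 1, :, valid]], dim=2)
    v = torch.cat([value[idx : idx + 1], encoder_value[idx : idx + 1, :, valid]], dim=2)

    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    # Pad the sequence dimension back to video_len + prompt_len so samples stack again.
    out = F.pad(out, (0, 0, 0, total_length - out.size(2)))
    outputs.append(out)

hidden_states = torch.cat(outputs, dim=0)  # (batch, heads, total_length, dim)
video_part, prompt_part = hidden_states.split_with_sizes((video_len, prompt_len), dim=2)

In the actual processor the same split happens after transposing back to (batch, sequence, heads * dim), so the zero-padded positions land in the encoder_hidden_states branch.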

src/diffusers/models/embeddings.py (-1)

@@ -542,7 +542,6 @@ def forward(self, latent):
             height, width = latent.shape[-2:]
         else:
             height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-
         latent = self.proj(latent)
         if self.flatten:
             latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC

src/diffusers/models/normalization.py (+30 -27)

@@ -234,33 +234,6 @@ def forward(
         return x, gate_msa, scale_mlp, gate_mlp
 
 
-class MochiRMSNormZero(nn.Module):
-    r"""
-    Adaptive RMS Norm used in Mochi.
-
-    Parameters:
-        embedding_dim (`int`): The size of each embedding vector.
-    """
-
-    def __init__(
-        self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
-    ) -> None:
-        super().__init__()
-
-        self.silu = nn.SiLU()
-        self.linear = nn.Linear(embedding_dim, hidden_dim)
-        self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
-
-    def forward(
-        self, hidden_states: torch.Tensor, emb: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        emb = self.linear(self.silu(emb))
-        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
-        hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
-
-        return hidden_states, gate_msa, scale_mlp, gate_mlp
-
-
 class AdaLayerNormSingle(nn.Module):
     r"""
     Norm layer adaptive layer norm single (adaLN-single).
@@ -549,6 +522,36 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
+# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            hidden_states = hidden_states * self.weight
+        hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+
 class GlobalResponseNorm(nn.Module):
     # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
     def __init__(self, dim):
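
The MochiRMSNorm added above computes its statistics in float32 and only casts back to the input dtype at the end, which, per the TODO, is why a dedicated class is kept until `_keep_in_fp32_modules` works for sharded checkpoints. A short standalone sketch of the same computation follows; the helper name, the bfloat16 input, and the feature size are illustrative assumptions.

from typing import Optional

import torch

def rms_norm_fp32(hidden_states: torch.Tensor, weight: Optional[torch.Tensor], eps: float = 1e-5) -> torch.Tensor:
    # Mean of squares over the last dimension, taken in float32 for numerical stability.
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    if weight is not None:
        hidden_states = hidden_states * weight
    # Cast back only at the very end, mirroring MochiRMSNorm.forward.
    return hidden_states.to(input_dtype)

x = torch.randn(2, 16, 64, dtype=torch.bfloat16)  # assumed (batch, tokens, dim) activation
y = rms_norm_fp32(x, torch.ones(64))
print(y.dtype, y.shape)  # torch.bfloat16 torch.Size([2, 16, 64])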
