Skip to content

Commit b8cf84a

Browse files
maxin-cn, sayakpaul, yiyixuxu, a-r-r-o-w
authored
Latte: Latent Diffusion Transformer for Video Generation (#8404)
* add Latte to diffusers * remove print * remove print * remove print * remove unuse codes * remove layer_norm_latte and add a flag * remove layer_norm_latte and add a flag * update latte_pipeline * update latte_pipeline * remove unuse squeeze * add norm_hidden_states.ndim == 2: # for Latte * fixed test latte pipeline bugs * fixed test latte pipeline bugs * delete sh * add doc for latte * add licensing * Move Transformer3DModelOutput to modeling_outputs * give a default value to sample_size * remove the einops dependency * change norm2 for latte * modify pipeline of latte * update test for Latte * modify some codes for latte * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * modify for Latte pipeline * video_length -> num_frames; update prepare_latents copied from * make fix-copies * make style * typo: videe -> video * update * modify for Latte pipeline * modify latte pipeline * modify latte pipeline * modify latte pipeline * modify latte pipeline * modify for Latte pipeline * Delete .vscode directory * make style * make fix-copies * add latte transformer 3d to docs _toctree.yml * update example * reduce frames for test * fixed bug of _text_preprocessing * set num frame to 1 for testing * remove unuse print * add text = self._clean_caption(text) again --------- Co-authored-by: Sayak Paul <[email 
protected]> Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: Aryan <[email protected]> Co-authored-by: Aryan <[email protected]>
1 parent 673eb60 commit b8cf84a

15 files changed

+1617
-3
lines changed

Diff for: .gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,4 +175,4 @@ tags
175175
.ruff_cache
176176

177177
# wandb
178-
wandb
178+
wandb

Diff for: docs/source/en/_toctree.yml

+2
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,8 @@
249249
title: DiTTransformer2DModel
250250
- local: api/models/hunyuan_transformer2d
251251
title: HunyuanDiT2DModel
252+
- local: api/models/latte_transformer3d
253+
title: LatteTransformer3DModel
252254
- local: api/models/lumina_nextdit2d
253255
title: LuminaNextDiT2DModel
254256
- local: api/models/transformer_temporal

Diff for: docs/source/en/api/models/latte_transformer3d.md

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# LatteTransformer3DModel
14+
15+
A Diffusion Transformer model for 3D data from [Latte](https://github.com/Vchitect/Latte).
16+
17+
## LatteTransformer3DModel
18+
19+
[[autodoc]] LatteTransformer3DModel

Diff for: src/diffusers/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
"HunyuanDiT2DMultiControlNetModel",
8989
"I2VGenXLUNet",
9090
"Kandinsky3UNet",
91+
"LatteTransformer3DModel",
9192
"LuminaNextDiT2DModel",
9293
"ModelMixin",
9394
"MotionAdapter",
@@ -269,6 +270,7 @@
269270
"KandinskyV22PriorPipeline",
270271
"LatentConsistencyModelImg2ImgPipeline",
271272
"LatentConsistencyModelPipeline",
273+
"LattePipeline",
272274
"LDMTextToImagePipeline",
273275
"LEditsPPPipelineStableDiffusion",
274276
"LEditsPPPipelineStableDiffusionXL",
@@ -513,6 +515,7 @@
513515
HunyuanDiT2DMultiControlNetModel,
514516
I2VGenXLUNet,
515517
Kandinsky3UNet,
518+
LatteTransformer3DModel,
516519
LuminaNextDiT2DModel,
517520
ModelMixin,
518521
MotionAdapter,
@@ -672,6 +675,7 @@
672675
KandinskyV22PriorPipeline,
673676
LatentConsistencyModelImg2ImgPipeline,
674677
LatentConsistencyModelPipeline,
678+
LattePipeline,
675679
LDMTextToImagePipeline,
676680
LEditsPPPipelineStableDiffusion,
677681
LEditsPPPipelineStableDiffusionXL,

Diff for: src/diffusers/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
4242
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
4343
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
44+
_import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
4445
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
4546
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
4647
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
@@ -86,6 +87,7 @@
8687
DiTTransformer2DModel,
8788
DualTransformer2DModel,
8889
HunyuanDiT2DModel,
90+
LatteTransformer3DModel,
8991
LuminaNextDiT2DModel,
9092
PixArtTransformer2DModel,
9193
PriorTransformer,

Diff for: src/diffusers/models/attention.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,10 @@ def __init__(
359359
out_bias=attention_out_bias,
360360
) # is self-attn if encoder_hidden_states is none
361361
else:
362-
self.norm2 = None
362+
if norm_type == "ada_norm_single": # For Latte
363+
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
364+
else:
365+
self.norm2 = None
363366
self.attn2 = None
364367

365368
# 3. Feed-forward
@@ -439,7 +442,6 @@ def forward(
439442
).chunk(6, dim=1)
440443
norm_hidden_states = self.norm1(hidden_states)
441444
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
442-
norm_hidden_states = norm_hidden_states.squeeze(1)
443445
else:
444446
raise ValueError("Incorrect norm used")
445447

@@ -456,6 +458,7 @@ def forward(
456458
attention_mask=attention_mask,
457459
**cross_attention_kwargs,
458460
)
461+
459462
if self.norm_type == "ada_norm_zero":
460463
attn_output = gate_msa.unsqueeze(1) * attn_output
461464
elif self.norm_type == "ada_norm_single":

Diff for: src/diffusers/models/transformers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .dit_transformer_2d import DiTTransformer2DModel
66
from .dual_transformer_2d import DualTransformer2DModel
77
from .hunyuan_transformer_2d import HunyuanDiT2DModel
8+
from .latte_transformer_3d import LatteTransformer3DModel
89
from .lumina_nextdit2d import LuminaNextDiT2DModel
910
from .pixart_transformer_2d import PixArtTransformer2DModel
1011
from .prior_transformer import PriorTransformer

0 commit comments

Comments (0)