
Commit 81440fd

zhuole1025, csuhan, yiyixuxu, a-r-r-o-w, and hlky authored
Add support for lumina2 (#10642)
* Add support for lumina2

---------

Co-authored-by: csuhan <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Aryan <[email protected]>
Co-authored-by: hlky <[email protected]>
1 parent c470274 · commit 81440fd

19 files changed (+1725, −4 lines)

docs/source/en/_toctree.yml (+4)

@@ -290,6 +290,8 @@
       title: LatteTransformer3DModel
     - local: api/models/lumina_nextdit2d
       title: LuminaNextDiT2DModel
+    - local: api/models/lumina2_transformer2d
+      title: Lumina2Transformer2DModel
     - local: api/models/ltx_video_transformer3d
       title: LTXVideoTransformer3DModel
     - local: api/models/mochi_transformer3d
@@ -442,6 +444,8 @@
       title: LEDITS++
     - local: api/pipelines/ltx_video
       title: LTXVideo
+    - local: api/pipelines/lumina2
+      title: Lumina 2.0
     - local: api/pipelines/lumina
       title: Lumina-T2X
     - local: api/pipelines/marigold
docs/source/en/api/models/lumina2_transformer2d.md (+31)

@@ -0,0 +1,31 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License. -->
+
+# Lumina2Transformer2DModel
+
+A Diffusion Transformer model for 2D image data, introduced in [Lumina Image 2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) by Alpha-VLLM.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+from diffusers import Lumina2Transformer2DModel
+
+transformer = Lumina2Transformer2DModel.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## Lumina2Transformer2DModel
+
+[[autodoc]] Lumina2Transformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
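Since the snippet above loads the transformer on its own, a natural follow-up is wiring it into the text-to-image pipeline this commit adds. A minimal sketch, assuming the standard diffusers pattern of passing pre-loaded components as keyword arguments to `from_pretrained`; the dtype and checkpoint layout follow the snippet above.

```python
import torch
from diffusers import Lumina2Text2ImgPipeline, Lumina2Transformer2DModel

# Load the transformer separately, e.g. to reuse a fine-tuned or quantized copy.
transformer = Lumina2Transformer2DModel.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Hand the pre-loaded transformer to the pipeline instead of letting it load its own.
pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
```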
docs/source/en/api/pipelines/lumina2.md (+33)

@@ -0,0 +1,33 @@
+<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License. -->
+
+# Lumina2
+
+[Lumina Image 2.0: A Unified and Efficient Image Generative Model](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) is a 2 billion parameter flow-based diffusion transformer capable of generating diverse images from text descriptions.
+
+The abstract from the paper is:
+
+*We introduce Lumina-Image 2.0, an advanced text-to-image model that surpasses previous state-of-the-art methods across multiple benchmarks, while also shedding light on its potential to evolve into a generalist vision intelligence model. Lumina-Image 2.0 exhibits three key properties: (1) Unification – it adopts a unified architecture that treats text and image tokens as a joint sequence, enabling natural cross-modal interactions and facilitating task expansion. Besides, since high-quality captioners can provide semantically better-aligned text-image training pairs, we introduce a unified captioning system, UniCaptioner, which generates comprehensive and precise captions for the model. This not only accelerates model convergence but also enhances prompt adherence, variable-length prompt handling, and task generalization via prompt templates. (2) Efficiency – to improve the efficiency of the unified architecture, we develop a set of optimization techniques that improve semantic learning and fine-grained texture generation during training while incorporating inference-time acceleration strategies without compromising image quality. (3) Transparency – we open-source all training details, code, and models to ensure full reproducibility, aiming to bridge the gap between well-resourced closed-source research teams and independent developers.*
+
+<Tip>
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+## Lumina2Text2ImgPipeline
+
+[[autodoc]] Lumina2Text2ImgPipeline
+  - all
+  - __call__
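The pipeline doc stops at the autodoc stubs; for context, end-to-end usage would look something like the sketch below, assuming the standard diffusers text-to-image call signature (`prompt`, `num_inference_steps`, `.images`). The prompt and step count are illustrative, not taken from this commit.

```python
import torch
from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(
    prompt="A photo of a red panda curled up on a mossy branch",
    num_inference_steps=30,  # illustrative value, not from this commit
).images[0]
image.save("lumina2.png")
```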

src/diffusers/__init__.py (+4)

@@ -118,6 +118,7 @@
         "Kandinsky3UNet",
         "LatteTransformer3DModel",
         "LTXVideoTransformer3DModel",
+        "Lumina2Transformer2DModel",
         "LuminaNextDiT2DModel",
         "MochiTransformer3DModel",
         "ModelMixin",
@@ -338,6 +339,7 @@
         "LEditsPPPipelineStableDiffusionXL",
         "LTXImageToVideoPipeline",
         "LTXPipeline",
+        "Lumina2Text2ImgPipeline",
         "LuminaText2ImgPipeline",
         "MarigoldDepthPipeline",
         "MarigoldNormalsPipeline",
@@ -634,6 +636,7 @@
         Kandinsky3UNet,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
+        Lumina2Transformer2DModel,
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
         ModelMixin,
@@ -833,6 +836,7 @@
         LEditsPPPipelineStableDiffusionXL,
         LTXImageToVideoPipeline,
         LTXPipeline,
+        Lumina2Text2ImgPipeline,
         LuminaText2ImgPipeline,
         MarigoldDepthPipeline,
         MarigoldNormalsPipeline,

src/diffusers/models/__init__.py (+2)

@@ -72,6 +72,7 @@
     _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
     _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
     _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
+    _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
@@ -141,6 +142,7 @@
         HunyuanVideoTransformer3DModel,
         LatteTransformer3DModel,
         LTXVideoTransformer3DModel,
+        Lumina2Transformer2DModel,
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
         OmniGenTransformer2DModel,

src/diffusers/models/attention.py (-1)

@@ -612,7 +612,6 @@ def __init__(
         ffn_dim_multiplier: Optional[float] = None,
     ):
         super().__init__()
-        inner_dim = int(2 * inner_dim / 3)
         # custom hidden_size factor multiplier
         if ffn_dim_multiplier is not None:
             inner_dim = int(ffn_dim_multiplier * inner_dim)
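The deleted line baked the SwiGLU width convention into `LuminaFeedForward`: gated feed-forwards use three projections instead of two, so the hidden width is conventionally scaled by 2/3 to keep the parameter count near a plain `4 * dim` MLP (the LLaMA convention). After this commit the caller passes the already-scaled width (see `lumina_nextdit2d.py` below). A hypothetical standalone sketch of that layout, not the diffusers implementation itself:

```python
import torch
import torch.nn as nn


class SwiGLUFeedForward(nn.Module):
    """Sketch of a gated (SwiGLU-style) feed-forward block.

    Three weight matrices instead of two, hence the conventional 2/3
    scaling of inner_dim to roughly match a plain 4*dim MLP's parameters.
    """

    def __init__(self, dim: int, inner_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu(gate(x)) * up(x), projected back down to dim.
        return self.down_proj(nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))


ff = SwiGLUFeedForward(dim=1024, inner_dim=int(4 * 2 * 1024 / 3))
out = ff(torch.randn(1, 16, 1024))  # shape preserved: (1, 16, 1024)
```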

src/diffusers/models/normalization.py (+11, -2)

@@ -219,14 +219,13 @@ def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine:
             4 * embedding_dim,
             bias=True,
         )
-        self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm = RMSNorm(embedding_dim, eps=norm_eps)

     def forward(
         self,
         x: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None])
@@ -515,6 +514,16 @@ def forward(self, hidden_states):
             hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
+        elif is_torch_version(">=", "2.4"):
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = nn.functional.rms_norm(
+                hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
+            )
+            if self.bias is not None:
+                hidden_states = hidden_states + self.bias
         else:
             input_dtype = hidden_states.dtype
             variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
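The new `elif` branch routes `RMSNorm` through `torch.nn.functional.rms_norm`, the fused kernel added in PyTorch 2.4, and keeps the manual float32 computation as a fallback for older versions. A quick equivalence check between the two paths, a sketch assuming PyTorch >= 2.4:

```python
import torch
import torch.nn as nn

dim = 64
hidden_states = torch.randn(2, 8, dim)
weight = torch.randn(dim)
eps = 1e-5

# Fused fast path (PyTorch >= 2.4).
fused = nn.functional.rms_norm(hidden_states, normalized_shape=(dim,), weight=weight, eps=eps)

# Manual reference, mirroring the existing else-branch above.
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
manual = hidden_states * torch.rsqrt(variance + eps) * weight

torch.testing.assert_close(fused, manual, rtol=1e-4, atol=1e-4)
```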

src/diffusers/models/transformers/__init__.py (+1)

@@ -21,6 +21,7 @@
     from .transformer_flux import FluxTransformer2DModel
     from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
     from .transformer_ltx import LTXVideoTransformer3DModel
+    from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
     from .transformer_omnigen import OmniGenTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel

src/diffusers/models/transformers/lumina_nextdit2d.py (+1, -1)

@@ -98,7 +98,7 @@ def __init__(

         self.feed_forward = LuminaFeedForward(
             dim=dim,
-            inner_dim=4 * dim,
+            inner_dim=int(4 * 2 * dim / 3),
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
         )
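This call-site change compensates for the line removed from `LuminaFeedForward` in `attention.py` above: previously the module received `inner_dim=4 * dim` and shrank it internally via `int(2 * inner_dim / 3)`; now the caller passes the shrunken value directly. A quick check (a sketch, not part of the commit) that the two agree for typical widths:

```python
for dim in (256, 1024, 2304, 4096):
    old = int(2 * (4 * dim) / 3)  # inner_dim=4*dim, then int(2 * inner_dim / 3) inside the module
    new = int(4 * 2 * dim / 3)    # value now passed directly by the caller
    assert old == new, (dim, old, new)
```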
