diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 752219b4abd1..ba038486f21b 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -89,6 +89,8 @@
title: Kandinsky
- local: using-diffusers/ip_adapter
title: IP-Adapter
+ - local: using-diffusers/omnigen
+ title: OmniGen
- local: using-diffusers/pag
title: PAG
- local: using-diffusers/controlnet
@@ -292,6 +294,8 @@
title: LTXVideoTransformer3DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
+ - local: api/models/omnigen_transformer
+ title: OmniGenTransformer2DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
@@ -448,6 +452,8 @@
title: MultiDiffusion
- local: api/pipelines/musicldm
title: MusicLDM
+ - local: api/pipelines/omnigen
+ title: OmniGen
- local: api/pipelines/pag
title: PAG
- local: api/pipelines/paint_by_example
diff --git a/docs/source/en/api/models/omnigen_transformer.md b/docs/source/en/api/models/omnigen_transformer.md
new file mode 100644
index 000000000000..ee700a04bdae
--- /dev/null
+++ b/docs/source/en/api/models/omnigen_transformer.md
@@ -0,0 +1,19 @@
+
+
+# OmniGenTransformer2DModel
+
+A Transformer model that accepts multimodal instructions to generate images for [OmniGen](https://github.com/VectorSpaceLab/OmniGen/).
+
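+The model can also be loaded on its own. The snippet below is a minimal sketch; it assumes the converted checkpoint stores the transformer weights in a `transformer` subfolder, which is the usual diffusers layout for [Shitao/OmniGen-v1-diffusers](https://huggingface.co/Shitao/OmniGen-v1-diffusers).
+
+```python
+import torch
+
+from diffusers import OmniGenTransformer2DModel
+
+# Load only the transformer from the converted checkpoint (assumes a "transformer" subfolder).
+transformer = OmniGenTransformer2DModel.from_pretrained(
+    "Shitao/OmniGen-v1-diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+```
+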
+## OmniGenTransformer2DModel
+
+[[autodoc]] OmniGenTransformer2DModel
diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md
new file mode 100644
index 000000000000..0b826f182edd
--- /dev/null
+++ b/docs/source/en/api/pipelines/omnigen.md
@@ -0,0 +1,106 @@
+
+
+# OmniGen
+
+[OmniGen: Unified Image Generation](https://arxiv.org/pdf/2409.11340) from BAAI, by Shitao Xiao, Yueze Wang, Junjie Zhou, Huaying Yuan, Xingrun Xing, Ruiran Yan, Chaofan Li, Shuting Wang, Tiejun Huang, Zheng Liu.
+
+The abstract from the paper is:
+
+*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/Shitao/OmniGen-v1](https://huggingface.co/Shitao/OmniGen-v1).
+
+
+## Inference
+
+First, load the pipeline:
+
+```python
+import torch
+from diffusers import OmniGenPipeline
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+```
+
+For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
+You can set the `height` and `width` parameters to generate images of a different size.
+
+```py
+prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
+image = pipe(
+ prompt=prompt,
+ height=1024,
+ width=1024,
+ guidance_scale=3,
+ generator=torch.Generator(device="cpu").manual_seed(111),
+).images[0]
+image
+```
+
+OmniGen supports multimodal inputs.
+When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image.
+It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
+
+```py
+from diffusers.utils import load_image
+
+prompt="<|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
+image
+```
+
+
+## OmniGenPipeline
+
+[[autodoc]] OmniGenPipeline
+ - all
+ - __call__
+
+
diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md
new file mode 100644
index 000000000000..a3d98e4e60cc
--- /dev/null
+++ b/docs/source/en/using-diffusers/omnigen.md
@@ -0,0 +1,314 @@
+
+# OmniGen
+
+OmniGen is an image generation model. Unlike existing text-to-image models, it is a single model designed to handle a variety of tasks (e.g., text-to-image generation, image editing, controllable generation). It has the following features:
+- Minimalist model architecture, consisting of only a VAE and a transformer module, for joint modeling of text and images.
+- Support for multimodal inputs. It can process arbitrary interleaved text and image data as instructions for image generation, rather than relying solely on text.
+
+For more information, please refer to the [paper](https://arxiv.org/pdf/2409.11340).
+This guide will walk you through using OmniGen for various tasks and use cases.
+
+## Load model checkpoints
+
+Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+```
+
+
+
+## Text-to-image
+
+For text-to-image, pass a text prompt. By default, OmniGen generates a 1024x1024 image.
+You can set the `height` and `width` parameters to generate images of a different size.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
+image = pipe(
+ prompt=prompt,
+ height=1024,
+ width=1024,
+ guidance_scale=3,
+ generator=torch.Generator(device="cpu").manual_seed(111),
+).images[0]
+image
+```
+
+

+
+
+## Image edit
+
+OmniGen supports multimodal inputs.
+When the input includes an image, you need to add a placeholder `<|image_1|>` in the text prompt to represent the image.
+It is recommended to enable `use_input_image_size_as_output` to keep the edited image the same size as the original image.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="<|image_1|> Remove the woman's earrings. Replace the mug with a clear glass filled with sparkling iced cola."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/t2i_woman_with_book.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(222)).images[0]
+image
+```
+
+
+

+
+original image
+
+
+

+
+edited image
+
+
+
+OmniGen has some interesting features, such as visual reasoning, as shown in the example below.
+```py
+prompt="If the woman is thirsty, what should she take? Find it in the image and highlight it in blue. <|image_1|>"
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
+image
+```
+
+

+
+
+
+## Controllable generation
+
+OmniGen can handle several classic computer vision tasks.
+As shown below, OmniGen can detect human skeletons in input images, which can be used as control conditions to generate new images.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="Detect the skeleton of human in this image: <|image_1|>"
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
+image1 = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
+image1
+
+prompt="Generate a new photo using the following picture and text as conditions: <|image_1|>\n A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/skeletal.png")]
+image2 = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(333)).images[0]
+image2
+```
+
+
+
+

+
original image
+
+
+

+
detected skeleton
+
+
+

+
skeleton to image
+
+
+
+
+OmniGen can also directly use relevant information from input images to generate new images.
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="Following the pose of this image <|image_1|>, generate a new photo: A young boy is sitting on a sofa in the library, holding a book. His hair is neatly combed, and a faint smile plays on his lips, with a few freckles scattered across his cheeks. The library is quiet, with rows of shelves filled with books stretching out behind him."
+input_images=[load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/edit.png")]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ guidance_scale=2,
+ img_guidance_scale=1.6,
+ use_input_image_size_as_output=True,
+ generator=torch.Generator(device="cpu").manual_seed(0)).images[0]
+image
+```
+
+
+

+
generated image
+
+
+
+
+## ID and object preserving
+
+OmniGen can generate new images based on the people and objects in the input images, and it supports using multiple input images simultaneously.
+Additionally, OmniGen can extract the desired objects from an image containing multiple objects based on the instructions.
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>"
+input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png")
+input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png")
+input_images=[input_image_1, input_image_2]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ height=1024,
+ width=1024,
+ guidance_scale=2.5,
+ img_guidance_scale=1.6,
+ generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
+image
+```
+
+
+

+
input_image_1
+
+
+

+
input_image_2
+
+
+

+
generated image
+
+
+
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+ "Shitao/OmniGen-v1-diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+
+prompt="A woman is walking down the street, wearing a white long-sleeve blouse with lace details on the sleeves, paired with a blue pleated skirt. The woman is <|image_1|>. The long-sleeve blouse and a pleated skirt are <|image_2|>."
+input_image_1 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/emma.jpeg")
+input_image_2 = load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/dress.jpg")
+input_images=[input_image_1, input_image_2]
+image = pipe(
+ prompt=prompt,
+ input_images=input_images,
+ height=1024,
+ width=1024,
+ guidance_scale=2.5,
+ img_guidance_scale=1.6,
+ generator=torch.Generator(device="cpu").manual_seed(666)).images[0]
+image
+```
+
+
+
+

+
person image
+
+
+

+
clothe image
+
+
+

+
generated image
+
+
+
+
+## Optimization when inputting multiple images
+
+For the text-to-image task, OmniGen has modest memory and time requirements (9GB of memory and 31s for a 1024x1024 image on an A800 GPU).
+However, when using input images, the computational cost increases.
+
+Here are some guidelines to help you reduce computational costs when inputting multiple images. The experiments are conducted on an A800 GPU with two input images.
+
+Like other pipelines, you can reduce memory usage by offloading the model with `pipe.enable_model_cpu_offload()` or `pipe.enable_sequential_cpu_offload()`.
+In OmniGen, you can also decrease computational overhead by reducing `max_input_image_size`, as shown in the example after the table.
+The memory consumption for different maximum input image sizes is shown in the table below:
+
+| Setting                   | Memory Usage |
+|---------------------------|--------------|
+| max_input_image_size=1024 | 40GB |
+| max_input_image_size=512 | 17GB |
+| max_input_image_size=256 | 14GB |
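+
+The snippet below is a minimal sketch that combines both options: model CPU offloading plus a smaller `max_input_image_size` passed to the pipeline call (assumed to be exposed as a call argument, per the note above).
+
+```py
+import torch
+from diffusers import OmniGenPipeline
+from diffusers.utils import load_image
+
+pipe = OmniGenPipeline.from_pretrained(
+    "Shitao/OmniGen-v1-diffusers",
+    torch_dtype=torch.bfloat16
+)
+# Offload submodules to the CPU when they are not in use to reduce peak memory usage.
+pipe.enable_model_cpu_offload()
+
+prompt="A man and a woman are sitting at a classroom desk. The man is the man with yellow hair in <|image_1|>. The woman is the woman on the left of <|image_2|>"
+input_images=[
+    load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/3.png"),
+    load_image("https://raw.githubusercontent.com/VectorSpaceLab/OmniGen/main/imgs/docs_img/4.png"),
+]
+image = pipe(
+    prompt=prompt,
+    input_images=input_images,
+    guidance_scale=2.5,
+    img_guidance_scale=1.6,
+    max_input_image_size=512,  # downscale the input images before encoding to reduce memory
+    generator=torch.Generator(device="cpu").manual_seed(666),
+).images[0]
+image
+```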
+
+
+
diff --git a/scripts/convert_omnigen_to_diffusers.py b/scripts/convert_omnigen_to_diffusers.py
new file mode 100644
index 000000000000..96bc935633f0
--- /dev/null
+++ b/scripts/convert_omnigen_to_diffusers.py
@@ -0,0 +1,203 @@
+import argparse
+import os
+
+import torch
+from huggingface_hub import snapshot_download
+from safetensors.torch import load_file
+from transformers import AutoTokenizer
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
+
+
+def main(args):
+ # checkpoint from https://huggingface.co/Shitao/OmniGen-v1
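+    # Example invocation (uses the argparse defaults defined at the bottom of this script):
+    #   python scripts/convert_omnigen_to_diffusers.py --origin_ckpt_path Shitao/OmniGen-v1 --dump_path OmniGen-v1-diffusers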
+
+ if not os.path.exists(args.origin_ckpt_path):
+ print("Model not found, downloading...")
+ cache_folder = os.getenv("HF_HUB_CACHE")
+ args.origin_ckpt_path = snapshot_download(
+ repo_id=args.origin_ckpt_path,
+ cache_dir=cache_folder,
+ ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
+ )
+ print(f"Downloaded model to {args.origin_ckpt_path}")
+
+ ckpt = os.path.join(args.origin_ckpt_path, "model.safetensors")
+ ckpt = load_file(ckpt, device="cpu")
+
+ mapping_dict = {
+ "pos_embed": "patch_embedding.pos_embed",
+ "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
+ "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
+ "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
+ "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
+ "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
+ "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
+ "final_layer.linear.weight": "proj_out.weight",
+ "final_layer.linear.bias": "proj_out.bias",
+ "time_token.mlp.0.weight": "time_token.linear_1.weight",
+ "time_token.mlp.0.bias": "time_token.linear_1.bias",
+ "time_token.mlp.2.weight": "time_token.linear_2.weight",
+ "time_token.mlp.2.bias": "time_token.linear_2.bias",
+ "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
+ "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
+ "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
+ "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
+ "llm.embed_tokens.weight": "embed_tokens.weight",
+ }
+
+ converted_state_dict = {}
+ for k, v in ckpt.items():
+ if k in mapping_dict:
+ converted_state_dict[mapping_dict[k]] = v
+ elif "qkv" in k:
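+            # the fused qkv projection is split into separate q/k/v weights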
+ to_q, to_k, to_v = v.chunk(3)
+ converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_q.weight"] = to_q
+ converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_k.weight"] = to_k
+ converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_v.weight"] = to_v
+ elif "o_proj" in k:
+ converted_state_dict[f"layers.{k.split('.')[2]}.self_attn.to_out.0.weight"] = v
+ else:
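+            # remaining (backbone) keys are kept as-is with the leading "llm." prefix stripped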
+ converted_state_dict[k[4:]] = v
+
+ transformer = OmniGenTransformer2DModel(
+ rope_scaling={
+ "long_factor": [
+ 1.0299999713897705,
+ 1.0499999523162842,
+ 1.0499999523162842,
+ 1.0799999237060547,
+ 1.2299998998641968,
+ 1.2299998998641968,
+ 1.2999999523162842,
+ 1.4499999284744263,
+ 1.5999999046325684,
+ 1.6499998569488525,
+ 1.8999998569488525,
+ 2.859999895095825,
+ 3.68999981880188,
+ 5.419999599456787,
+ 5.489999771118164,
+ 5.489999771118164,
+ 9.09000015258789,
+ 11.579999923706055,
+ 15.65999984741211,
+ 15.769999504089355,
+ 15.789999961853027,
+ 18.360000610351562,
+ 21.989999771118164,
+ 23.079999923706055,
+ 30.009998321533203,
+ 32.35000228881836,
+ 32.590003967285156,
+ 35.56000518798828,
+ 39.95000457763672,
+ 53.840003967285156,
+ 56.20000457763672,
+ 57.95000457763672,
+ 59.29000473022461,
+ 59.77000427246094,
+ 59.920005798339844,
+ 61.190006256103516,
+ 61.96000671386719,
+ 62.50000762939453,
+ 63.3700065612793,
+ 63.48000717163086,
+ 63.48000717163086,
+ 63.66000747680664,
+ 63.850006103515625,
+ 64.08000946044922,
+ 64.760009765625,
+ 64.80001068115234,
+ 64.81001281738281,
+ 64.81001281738281,
+ ],
+ "short_factor": [
+ 1.05,
+ 1.05,
+ 1.05,
+ 1.1,
+ 1.1,
+ 1.1,
+ 1.2500000000000002,
+ 1.2500000000000002,
+ 1.4000000000000004,
+ 1.4500000000000004,
+ 1.5500000000000005,
+ 1.8500000000000008,
+ 1.9000000000000008,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.000000000000001,
+ 2.1000000000000005,
+ 2.1000000000000005,
+ 2.2,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3499999999999996,
+ 2.3999999999999995,
+ 2.3999999999999995,
+ 2.6499999999999986,
+ 2.6999999999999984,
+ 2.8999999999999977,
+ 2.9499999999999975,
+ 3.049999999999997,
+ 3.049999999999997,
+ 3.049999999999997,
+ ],
+ "type": "su",
+ },
+ patch_size=2,
+ in_channels=4,
+ pos_embed_max_size=192,
+ )
+ transformer.load_state_dict(converted_state_dict, strict=True)
+ transformer.to(torch.bfloat16)
+
+ num_model_params = sum(p.numel() for p in transformer.parameters())
+ print(f"Total number of transformer parameters: {num_model_params}")
+
+ scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)
+
+ vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)
+
+ tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)
+
+ pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
+ pipeline.save_pretrained(args.dump_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--origin_ckpt_path",
+ default="Shitao/OmniGen-v1",
+ type=str,
+ required=False,
+ help="Path to the checkpoint to convert.",
+ )
+
+ parser.add_argument(
+ "--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline."
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index c36226225ad4..32386fab9a3b 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -124,6 +124,7 @@
"MotionAdapter",
"MultiAdapter",
"MultiControlNetModel",
+ "OmniGenTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"SanaTransformer2DModel",
@@ -342,6 +343,7 @@
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
+ "OmniGenPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
@@ -638,6 +640,7 @@
MotionAdapter,
MultiAdapter,
MultiControlNetModel,
+ OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaTransformer2DModel,
@@ -835,6 +838,7 @@
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
+ OmniGenPipeline,
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 57a34609d28e..eb09765b78cd 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -73,6 +73,7 @@
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
+ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -142,6 +143,7 @@
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
+ OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
SanaTransformer2DModel,
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 7db4d3d17d2f..1918c24d2be7 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -71,7 +71,7 @@ def forward(
if self.chunk_dim == 1:
# This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
- # other if-branch. This branch is specific to CogVideoX for now.
+ # other if-branch. This branch is specific to CogVideoX and OmniGen for now.
shift, scale = temb.chunk(2, dim=1)
shift = shift[:, None, :]
scale = scale[:, None, :]
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 77e1698b8fc2..aa09949fc398 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -22,5 +22,6 @@
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_mochi import MochiTransformer3DModel
+ from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
new file mode 100644
index 000000000000..0774a3f2a6ee
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -0,0 +1,699 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers
+from ..attention_processor import Attention, AttentionProcessor
+from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNorm, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class OmniGenFeedForward(nn.Module):
+ r"""
+ A feed-forward layer for OmniGen.
+
+ Parameters:
+ hidden_size (`int`):
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+ hidden representations.
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ intermediate_size: int,
+ ):
+ super().__init__()
+ self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+
+ self.activation_fn = nn.SiLU()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ up_states = self.gate_up_proj(hidden_states)
+
+ gate, up_states = up_states.chunk(2, dim=-1)
+ up_states = up_states * self.activation_fn(gate)
+
+ return self.down_proj(up_states)
+
+
+class OmniGenPatchEmbed(nn.Module):
+ """2D Image to Patch Embedding with support for OmniGen."""
+
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 4,
+ embed_dim: int = 768,
+ bias: bool = True,
+ interpolation_scale: float = 1,
+ pos_embed_max_size: int = 192,
+ base_size: int = 64,
+ ):
+ super().__init__()
+
+ self.output_image_proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ self.input_image_proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+
+ self.patch_size = patch_size
+ self.interpolation_scale = interpolation_scale
+ self.pos_embed_max_size = pos_embed_max_size
+
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim,
+ self.pos_embed_max_size,
+ base_size=base_size,
+ interpolation_scale=self.interpolation_scale,
+ output_type="pt",
+ )
+ self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=True)
+
+ def cropped_pos_embed(self, height, width):
+ """Crops positional embeddings for SD3 compatibility."""
+ if self.pos_embed_max_size is None:
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ if height > self.pos_embed_max_size:
+ raise ValueError(
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+ if width > self.pos_embed_max_size:
+ raise ValueError(
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+
+ top = (self.pos_embed_max_size - height) // 2
+ left = (self.pos_embed_max_size - width) // 2
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+ return spatial_pos_embed
+
+ def patch_embeddings(self, latent, is_input_image: bool):
+ if is_input_image:
+ latent = self.input_image_proj(latent)
+ else:
+ latent = self.output_image_proj(latent)
+ latent = latent.flatten(2).transpose(1, 2)
+ return latent
+
+ def forward(self, latent: torch.Tensor, is_input_image: bool, padding_latent: torch.Tensor = None):
+ """
+ Args:
+ latent: encoded image latents
+ is_input_image: use input_image_proj or output_image_proj
+ padding_latent:
+ When sizes of target images are inconsistent, use `padding_latent` to maintain consistent sequence
+ length.
+
+ Returns: torch.Tensor
+
+ """
+ if isinstance(latent, list):
+ if padding_latent is None:
+ padding_latent = [None] * len(latent)
+ patched_latents = []
+ for sub_latent, padding in zip(latent, padding_latent):
+ height, width = sub_latent.shape[-2:]
+ sub_latent = self.patch_embeddings(sub_latent, is_input_image)
+ pos_embed = self.cropped_pos_embed(height, width)
+ sub_latent = sub_latent + pos_embed
+ if padding is not None:
+ sub_latent = torch.cat([sub_latent, padding.to(sub_latent.device)], dim=-2)
+ patched_latents.append(sub_latent)
+ else:
+ height, width = latent.shape[-2:]
+ pos_embed = self.cropped_pos_embed(height, width)
+ latent = self.patch_embeddings(latent, is_input_image)
+ patched_latents = latent + pos_embed
+
+ return patched_latents
+
+
+class OmniGenSuScaledRotaryEmbedding(nn.Module):
+ def __init__(
+ self, dim, max_position_embeddings=131072, original_max_position_embeddings=4096, base=10000, rope_scaling=None
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+
+ self.short_factor = rope_scaling["short_factor"]
+ self.long_factor = rope_scaling["long_factor"]
+ self.original_max_position_embeddings = original_max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.original_max_position_embeddings:
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+ else:
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
+ if scale <= 1.0:
+ scaling_factor = 1.0
+ else:
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
+ cos = emb.cos() * scaling_factor
+ sin = emb.sin() * scaling_factor
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def apply_rotary_emb(
+ x: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+) -> torch.Tensor:
+    """
+    Apply rotary embeddings to the input tensor using the given frequency tensors. This function applies rotary
+    embeddings to the given query or key tensor `x` using the cosine and sine tensors in `freqs_cis`, following the
+    rotate-half formulation, and returns the rotated tensor in the original dtype.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor of shape `[B, H, S, D]` to apply rotary embeddings to.
+        freqs_cis (`Tuple[torch.Tensor]`):
+            Precomputed cosine and sine frequency tensors, each of shape `[S, D]`.
+
+    Returns:
+        `torch.Tensor`: The query or key tensor with rotary embeddings applied.
+    """
+
+ cos, sin = freqs_cis # [S, D]
+ if len(cos.shape) == 2:
+ cos = cos[None, None]
+ sin = sin[None, None]
+ elif len(cos.shape) == 3:
+ cos = cos[:, None]
+ sin = sin[:, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+
+    # Rotate half the hidden dims of the input. This rotation is widely used in LLMs, e.g. Llama, Phi-3, etc.
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ x_rotated = torch.cat((-x2, x1), dim=-1)
+
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+ return out
+
+
+class OmniGenAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the OmniGen model.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ # Get Query-Key-Value Pair
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ bsz, q_len, query_dim = query.size()
+ inner_dim = key.shape[-1]
+ head_dim = query_dim // attn.heads
+ dtype = query.dtype
+
+ # Get key-value heads
+ kv_heads = inner_dim // head_dim
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ query, key = query.to(dtype), key.to(dtype)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+ hidden_states = hidden_states.transpose(1, 2).to(dtype)
+ hidden_states = hidden_states.reshape(bsz, q_len, attn.out_dim)
+ hidden_states = attn.to_out[0](hidden_states)
+ return hidden_states
+
+
+class OmniGenBlock(nn.Module):
+ """
+    A transformer block used in OmniGenTransformer2DModel.
+
+ Parameters:
+ hidden_size (`int`): Embedding dimension of the input features.
+ num_attention_heads (`int`): Number of attention heads.
+ num_key_value_heads (`int`):
+ Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
+ intermediate_size (`int`): size of intermediate layer.
+ rms_norm_eps (`float`): The eps for norm layer.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_attention_heads: int,
+ num_key_value_heads: int,
+ intermediate_size: int,
+ rms_norm_eps: float,
+ ) -> None:
+ super().__init__()
+
+ self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+ self.self_attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=hidden_size,
+ dim_head=hidden_size // num_attention_heads,
+ heads=num_attention_heads,
+ kv_heads=num_key_value_heads,
+ bias=False,
+ out_dim=hidden_size,
+ out_bias=False,
+ processor=OmniGenAttnProcessor2_0(),
+ )
+ self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps)
+ self.mlp = OmniGenFeedForward(hidden_size, intermediate_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ image_rotary_emb: torch.Tensor,
+ ):
+ """
+        Perform a forward pass through the OmniGenBlock.
+
+        Parameters:
+            hidden_states (`torch.Tensor`): The input hidden states for the OmniGenBlock.
+            attention_mask (`torch.Tensor`): The attention mask corresponding to `hidden_states`.
+ image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies.
+ """
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ attn_outputs = self.self_attn(
+ hidden_states=hidden_states,
+ encoder_hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = residual + attn_outputs
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states
+
+
+class OmniGenTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ The Transformer model introduced in OmniGen.
+
+ Reference: https://arxiv.org/pdf/2409.11340
+
+ Parameters:
+ hidden_size (`int`, *optional*, defaults to 3072):
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+ hidden representations.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-5): eps for RMSNorm layer.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ The number of attention heads in each attention layer. This parameter specifies how many separate attention
+ mechanisms are used.
+        num_key_value_heads (`int`, *optional*, defaults to 32):
+            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
+            If None, it defaults to `num_attention_heads`.
+        intermediate_size (`int`, *optional*, defaults to 8192): Dimension of the intermediate layer in the FFN.
+        num_layers (`int`, *optional*, defaults to 32):
+            The number of layers in the model. This defines the depth of the neural network.
+        pad_token_id (`int`, *optional*, defaults to 32000):
+            The id of the padding token.
+        vocab_size (`int`, *optional*, defaults to 32064):
+            The size of the vocabulary.
+ patch_size (`int`, defaults to 2): Patch size to turn the input data into small patches.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
+ pos_embed_max_size (`int`, *optional*, defaults to 192): The max size of pos emb.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["OmniGenBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ hidden_size: int = 3072,
+ rms_norm_eps: float = 1e-05,
+ num_attention_heads: int = 32,
+ num_key_value_heads: int = 32,
+ intermediate_size: int = 8192,
+ num_layers: int = 32,
+ pad_token_id: int = 32000,
+ vocab_size: int = 32064,
+ max_position_embeddings: int = 131072,
+ original_max_position_embeddings: int = 4096,
+ rope_base: int = 10000,
+ rope_scaling: Dict = None,
+ patch_size=2,
+ in_channels=4,
+ pos_embed_max_size: int = 192,
+ time_step_dim: int = 256,
+ flip_sin_to_cos: bool = True,
+ downscale_freq_shift: int = 0,
+ timestep_activation_fn: str = "silu",
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.patch_size = patch_size
+ self.pos_embed_max_size = pos_embed_max_size
+
+ self.patch_embedding = OmniGenPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=hidden_size,
+ pos_embed_max_size=pos_embed_max_size,
+ )
+
+ self.time_proj = Timesteps(time_step_dim, flip_sin_to_cos, downscale_freq_shift)
+ self.time_token = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
+ self.t_embedder = TimestepEmbedding(time_step_dim, hidden_size, timestep_activation_fn)
+
+ self.norm_out = AdaLayerNorm(hidden_size, norm_elementwise_affine=False, norm_eps=1e-6, chunk_dim=1)
+ self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)
+ self.rotary_emb = OmniGenSuScaledRotaryEmbedding(
+ hidden_size // num_attention_heads,
+ max_position_embeddings=max_position_embeddings,
+ original_max_position_embeddings=original_max_position_embeddings,
+ base=rope_base,
+ rope_scaling=rope_scaling,
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ OmniGenBlock(
+ hidden_size,
+ num_attention_heads,
+ num_key_value_heads,
+ intermediate_size,
+ rms_norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm = RMSNorm(hidden_size, eps=rms_norm_eps)
+
+ self.gradient_checkpointing = False
+
+ def unpatchify(self, x, h, w):
+ """
+ x: (N, T, patch_size**2 * C) imgs: (N, H, W, C)
+ """
+ c = self.out_channels
+
+ x = x.reshape(
+ shape=(x.shape[0], h // self.patch_size, w // self.patch_size, self.patch_size, self.patch_size, c)
+ )
+ x = torch.einsum("nhwpqc->nchpwq", x)
+ imgs = x.reshape(shape=(x.shape[0], c, h, w))
+ return imgs
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[OmniGenAttnProcessor2_0, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def get_multimodal_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ input_img_latents: List[torch.Tensor],
+ input_image_sizes: Dict,
+ ):
+ """
+        Get the multimodal conditional embeddings.
+
+        Args:
+            input_ids: token ids of the input text sequence
+            input_img_latents: continuous (VAE-encoded) embeddings of the input images
+            input_image_sizes: the start and end indices of each input image within the `input_ids` sequence, per batch item.
+
+ Returns: torch.Tensor
+
+ """
+ input_img_latents = [x.to(self.dtype) for x in input_img_latents]
+ condition_tokens = None
+ if input_ids is not None:
+ condition_tokens = self.embed_tokens(input_ids)
+ input_img_inx = 0
+ if input_img_latents is not None:
+ input_image_tokens = self.patch_embedding(input_img_latents, is_input_image=True)
+
+ for b_inx in input_image_sizes.keys():
+ for start_inx, end_inx in input_image_sizes[b_inx]:
+ # replace the placeholder in text tokens with the image embedding.
+ condition_tokens[b_inx, start_inx:end_inx] = input_image_tokens[input_img_inx].to(
+ condition_tokens.dtype
+ )
+ input_img_inx += 1
+
+ return condition_tokens
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.FloatTensor],
+ input_ids: torch.Tensor,
+ input_img_latents: List[torch.Tensor],
+ input_image_sizes: Dict[int, List[int]],
+ attention_mask: torch.Tensor,
+ position_ids: torch.Tensor,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ """
+ The [`OmniGenTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ timestep (`torch.FloatTensor`):
+ Used to indicate denoising step.
+            input_ids (`torch.LongTensor`):
+                Token ids of the text prompt.
+            input_img_latents (`torch.Tensor`):
+                VAE-encoded latents of the input images.
+            input_image_sizes (`dict`):
+                The start and end indices of the input image latents within the `input_ids` sequence.
+            attention_mask (`torch.Tensor`):
+                Mask used for self-attention.
+            position_ids (`torch.LongTensor`):
+                Ids used to denote each token's position in the sequence.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, a [`~models.modeling_outputs.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+
+ """
+
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ height, width = hidden_states.size()[-2:]
+ hidden_states = self.patch_embedding(hidden_states, is_input_image=False)
+ num_tokens_for_output_image = hidden_states.size(1)
+
+ time_token = self.time_token(self.time_proj(timestep).to(hidden_states.dtype)).unsqueeze(1)
+
+ condition_tokens = self.get_multimodal_embeddings(
+ input_ids=input_ids,
+ input_img_latents=input_img_latents,
+ input_image_sizes=input_image_sizes,
+ )
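+        # sequence layout: [multimodal condition tokens, time token, noisy image tokens]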
+ if condition_tokens is not None:
+ inputs_embeds = torch.cat([condition_tokens, time_token, hidden_states], dim=1)
+ else:
+ inputs_embeds = torch.cat([time_token, hidden_states], dim=1)
+
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if attention_mask is not None and attention_mask.dim() == 3:
+ dtype = inputs_embeds.dtype
+ min_dtype = torch.finfo(dtype).min
+ attention_mask = (1 - attention_mask) * min_dtype
+ attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
+ else:
+ raise Exception("attention_mask parameter was unavailable or invalid")
+
+ hidden_states = inputs_embeds
+
+ image_rotary_emb = self.rotary_emb(hidden_states, position_ids)
+ for decoder_layer in self.layers:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ decoder_layer, hidden_states, attention_mask, image_rotary_emb
+ )
+ else:
+ hidden_states = decoder_layer(
+ hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb
+ )
+
+ hidden_states = self.norm(hidden_states)
+
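+        # keep only the output-image tokens at the end of the sequence for decoding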
+ hidden_states = hidden_states[:, -num_tokens_for_output_image:]
+ timestep_proj = self.time_proj(timestep)
+ temb = self.t_embedder(timestep_proj.type_as(hidden_states))
+ hidden_states = self.norm_out(hidden_states, temb=temb)
+ hidden_states = self.proj_out(hidden_states)
+ output = self.unpatchify(hidden_states, height, width)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 5829cf495dcc..d9869a8b406d 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -264,6 +264,7 @@
)
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
+ _import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
@@ -602,6 +603,7 @@
)
from .mochi import MochiPipeline
from .musicldm import MusicLDMPipeline
+ from .omnigen import OmniGenPipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py
index 0d4891cf17d7..1a99c2a0e9ee 100644
--- a/src/diffusers/pipelines/consisid/pipeline_consisid.py
+++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py
@@ -48,9 +48,14 @@
>>> from huggingface_hub import snapshot_download
>>> snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir="BestWishYsh/ConsisID-preview")
- >>> face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = (
- ... prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
- ... )
+ >>> (
+ ... face_helper_1,
+ ... face_helper_2,
+ ... face_clip_model,
+ ... face_main_model,
+ ... eva_transform_mean,
+ ... eva_transform_std,
+ ... ) = prepare_face_models("BestWishYsh/ConsisID-preview", device="cuda", dtype=torch.bfloat16)
>>> pipe = ConsisIDPipeline.from_pretrained("BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
diff --git a/src/diffusers/pipelines/omnigen/__init__.py b/src/diffusers/pipelines/omnigen/__init__.py
new file mode 100644
index 000000000000..557e7c08dc22
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_omnigen"] = ["OmniGenPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_omnigen import OmniGenPipeline
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
new file mode 100644
index 000000000000..41bfab5e3e04
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -0,0 +1,530 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import LlamaTokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import OmniGenTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .processor_omnigen import OmniGenMultiModalProcessor
+
+
+if is_torch_xla_available():
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import OmniGenPipeline
+
+ >>> pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
+ >>> image.save("t2i.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class OmniGenPipeline(
+ DiffusionPipeline,
+):
+ r"""
+ The OmniGen pipeline for multimodal-to-image generation.
+
+ Reference: https://arxiv.org/pdf/2409.11340
+
+ Args:
+ transformer ([`OmniGenTransformer2DModel`]):
+ Autoregressive Transformer architecture for OmniGen.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ tokenizer (`LlamaTokenizer`):
+            Text tokenizer of class
+ [LlamaTokenizer](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaTokenizer).
+ """
+
+ model_cpu_offload_seq = "transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents"]
+
+ def __init__(
+ self,
+ transformer: OmniGenTransformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ tokenizer: LlamaTokenizer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) is not None else 8
+ )
+        # OmniGen latents are turned into 2x2 patches and packed. This means the latent width and height have to be
+        # divisible by the patch size, so the VAE scale factor is multiplied by the patch size to account for this.
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ self.multimodal_processor = OmniGenMultiModalProcessor(tokenizer, max_image_size=1024)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 120000
+ )
+ self.default_sample_size = 128
+
+ def encode_input_images(
+ self,
+ input_pixel_values: List[torch.Tensor],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ get the continue embedding of input images by VAE
+
+ Args:
+ input_pixel_values: normlized pixel of input images
+ device:
+ Returns: torch.Tensor
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.vae.dtype
+
+ input_img_latents = []
+ for img in input_pixel_values:
+ img = self.vae.encode(img.to(device, dtype)).latent_dist.sample().mul_(self.vae.config.scaling_factor)
+ input_img_latents.append(img)
+ return input_img_latents
+
+ def check_inputs(
+ self,
+ prompt,
+ input_images,
+ height,
+ width,
+ use_input_image_size_as_output,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if input_images is not None:
+ if len(input_images) != len(prompt):
+ raise ValueError(
+ f"The number of prompts: {len(prompt)} does not match the number of input images: {len(input_images)}."
+ )
+ for i in range(len(input_images)):
+ if input_images[i] is not None:
+ if not all(f"
<|image_{k + 1}|>" in prompt[i] for k in range(len(input_images[i]))):
+ raise ValueError(
+ f"prompt `{prompt[i]}` doesn't have enough placeholders for the input images `{input_images[i]}`"
+ )
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if use_input_image_size_as_output:
+ if input_images is None or input_images[0] is None:
+ raise ValueError(
+ "`use_input_image_size_as_output` is set to True, but no input image was found. If you are performing a text-to-image task, please set it to False."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ input_images: Union[PipelineImageInput, List[PipelineImageInput]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ max_input_image_size: int = 1024,
+ timesteps: List[int] = None,
+ guidance_scale: float = 2.5,
+ img_guidance_scale: float = 1.6,
+ use_input_image_size_as_output: bool = False,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 120000,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If the input includes images, you need to add
+                placeholders `<|image_i|>` in the prompt to indicate the position of the i-th image.
+ input_images (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+                The list of input images. Each `<|image_i|>` placeholder in the prompt is replaced with the i-th
+                image in the list.
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ max_input_image_size (`int`, *optional*, defaults to 1024):
+                The maximum size of the input images; larger inputs are resized and cropped so that neither dimension
+                exceeds this value.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 2.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ img_guidance_scale (`float`, *optional*, defaults to 1.6):
+                Image guidance scale, as defined in equation 3 of [InstructPix2Pix](https://arxiv.org/pdf/2211.09800).
+            use_input_image_size_as_output (`bool`, *optional*, defaults to `False`):
+                Whether to use the input image size as the output image size. This is useful for single-image inputs,
+                e.g., image editing tasks.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 120000): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ num_cfg = 2 if input_images is not None else 1
+        use_img_cfg = input_images is not None
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ input_images = [input_images]
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ input_images,
+ height,
+ width,
+ use_input_image_size_as_output,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ batch_size = len(prompt)
+ device = self._execution_device
+
+ # 3. process multi-modal instructions
+ if max_input_image_size != self.multimodal_processor.max_image_size:
+ self.multimodal_processor.reset_max_image_size(max_image_size=max_input_image_size)
+ processed_data = self.multimodal_processor(
+ prompt,
+ input_images,
+ height=height,
+ width=width,
+ use_img_cfg=use_img_cfg,
+ use_input_image_size_as_output=use_input_image_size_as_output,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ processed_data["input_ids"] = processed_data["input_ids"].to(device)
+ processed_data["attention_mask"] = processed_data["attention_mask"].to(device)
+ processed_data["position_ids"] = processed_data["position_ids"].to(device)
+
+ # 4. Encode input images
+ input_img_latents = self.encode_input_images(processed_data["input_pixel_values"], device=device)
+
+ # 5. Prepare timesteps
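+        # Flow-matching sigmas run linearly from 1 to 0; the trailing 0 is dropped so that
+        # len(sigmas) == num_inference_steps.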
+ sigmas = np.linspace(1, 0, num_inference_steps + 1)[:num_inference_steps]
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
+ )
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latents.
+ if use_input_image_size_as_output:
+ height, width = processed_data["input_pixel_values"][0].shape[-2:]
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ self.transformer.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+        # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
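+                # When input images are present, the batch holds three copies of the latents (text-conditioned,
+                # unconditional, and image-conditioned predictions); otherwise two.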
+ latent_model_input = torch.cat([latents] * (num_cfg + 1))
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ input_ids=processed_data["input_ids"],
+ input_img_latents=input_img_latents,
+ input_image_sizes=processed_data["input_image_sizes"],
+ attention_mask=processed_data["attention_mask"],
+ position_ids=processed_data["position_ids"],
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
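+                # Two-level guidance: image guidance moves the unconditional prediction toward the image-conditioned
+                # one, and text guidance then moves it toward the fully conditioned prediction.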
+ if num_cfg == 2:
+ cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0)
+ noise_pred = uncond + img_guidance_scale * (img_cond - uncond) + guidance_scale * (cond - img_cond)
+ else:
+ cond, uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0)
+ noise_pred = uncond + guidance_scale * (cond - uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ progress_bar.update()
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents = latents / self.vae.config.scaling_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ else:
+ image = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
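
For orientation, here is a minimal sketch of how this pipeline is intended to be called with a reference image. It assumes the `Shitao/OmniGen-v1-diffusers` checkpoint from the example docstring above; the image URL and prompt wording are illustrative placeholders, and the only hard requirement (enforced by `check_inputs`) is one `<|image_i|>` placeholder per input image.

```py
import torch

from diffusers import OmniGenPipeline
from diffusers.utils import load_image

pipe = OmniGenPipeline.from_pretrained("Shitao/OmniGen-v1-diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# One `<|image_i|>` placeholder per entry in `input_images`; the URL below is purely illustrative.
prompt = "The woman in <|image_1|> is now reading a book in a sunlit cafe."
input_images = [load_image("https://example.com/reference.png")]

image = pipe(
    prompt=prompt,
    input_images=input_images,
    guidance_scale=2.5,        # text guidance
    img_guidance_scale=1.6,    # image guidance (InstructPix2Pix-style)
    num_inference_steps=50,
).images[0]
image.save("subject_driven.png")
```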
diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py
new file mode 100644
index 000000000000..75d272ac5140
--- /dev/null
+++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py
@@ -0,0 +1,327 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Dict, List
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+
+
+def crop_image(pil_image, max_image_size):
+ """
+    Crop the image so that its height and width do not exceed `max_image_size`, while ensuring both the height and
+ width are multiples of 16.
+ """
+ while min(*pil_image.size) >= 2 * max_image_size:
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+ if max(*pil_image.size) > max_image_size:
+ scale = max_image_size / max(*pil_image.size)
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ if min(*pil_image.size) < 16:
+ scale = 16 / min(*pil_image.size)
+ pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+ arr = np.array(pil_image)
+ crop_y1 = (arr.shape[0] % 16) // 2
+ crop_y2 = arr.shape[0] % 16 - crop_y1
+
+ crop_x1 = (arr.shape[1] % 16) // 2
+ crop_x2 = arr.shape[1] % 16 - crop_x1
+
+ arr = arr[crop_y1 : arr.shape[0] - crop_y2, crop_x1 : arr.shape[1] - crop_x2]
+ return Image.fromarray(arr)
+
+
+class OmniGenMultiModalProcessor:
+ def __init__(self, text_tokenizer, max_image_size: int = 1024):
+ self.text_tokenizer = text_tokenizer
+ self.max_image_size = max_image_size
+
+ self.image_transform = transforms.Compose(
+ [
+ transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+
+ self.collator = OmniGenCollator()
+
+ def reset_max_image_size(self, max_image_size):
+ self.max_image_size = max_image_size
+ self.image_transform = transforms.Compose(
+ [
+ transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+ ]
+ )
+
+ def process_image(self, image):
+ if isinstance(image, str):
+ image = Image.open(image).convert("RGB")
+ return self.image_transform(image)
+
+ def process_multi_modal_prompt(self, text, input_images):
+ text = self.add_prefix_instruction(text)
+ if input_images is None or len(input_images) == 0:
+ model_inputs = self.text_tokenizer(text)
+ return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
+
+ pattern = r"<\|image_\d+\|>"
+ prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
+
+ for i in range(1, len(prompt_chunks)):
+ if prompt_chunks[i][0] == 1:
+ prompt_chunks[i] = prompt_chunks[i][1:]
+
+ image_tags = re.findall(pattern, text)
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+
+ unique_image_ids = sorted(set(image_ids))
+ assert unique_image_ids == list(
+ range(1, len(unique_image_ids) + 1)
+ ), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ # total images must be the same as the number of image tags
+ assert (
+ len(unique_image_ids) == len(input_images)
+ ), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+
+ input_images = [input_images[x - 1] for x in image_ids]
+
+ all_input_ids = []
+ img_inx = []
+ for i in range(len(prompt_chunks)):
+ all_input_ids.extend(prompt_chunks[i])
+ if i != len(prompt_chunks) - 1:
+ start_inx = len(all_input_ids)
+ size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
+ img_inx.append([start_inx, start_inx + size])
+ all_input_ids.extend([0] * size)
+
+ return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
+
+ def add_prefix_instruction(self, prompt):
+ user_prompt = "<|user|>\n"
+ generation_prompt = "Generate an image according to the following instructions\n"
+ assistant_prompt = "<|assistant|>\n<|diffusion|>"
+ prompt_suffix = "<|end|>\n"
+ prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
+ return prompt
+
+ def __call__(
+ self,
+ instructions: List[str],
+ input_images: List[List[str]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
+ use_img_cfg: bool = True,
+ separate_cfg_input: bool = False,
+ use_input_image_size_as_output: bool = False,
+ num_images_per_prompt: int = 1,
+ ) -> Dict:
+ if isinstance(instructions, str):
+ instructions = [instructions]
+ input_images = [input_images]
+
+ input_data = []
+ for i in range(len(instructions)):
+ cur_instruction = instructions[i]
+ cur_input_images = None if input_images is None else input_images[i]
+ if cur_input_images is not None and len(cur_input_images) > 0:
+ cur_input_images = [self.process_image(x) for x in cur_input_images]
+ else:
+ cur_input_images = None
+ assert "
<|image_1|>" not in cur_instruction
+
+ mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
+
+ neg_mllm_input, img_cfg_mllm_input = None, None
+ neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
+ if use_img_cfg:
+ if cur_input_images is not None and len(cur_input_images) >= 1:
+ img_cfg_prompt = [f"
<|image_{i + 1}|>" for i in range(len(cur_input_images))]
+ img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
+ else:
+ img_cfg_mllm_input = neg_mllm_input
+
+ for _ in range(num_images_per_prompt):
+ if use_input_image_size_as_output:
+ input_data.append(
+ (
+ mllm_input,
+ neg_mllm_input,
+ img_cfg_mllm_input,
+ [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)],
+ )
+ )
+ else:
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
+
+ return self.collator(input_data)
+
+
+class OmniGenCollator:
+ def __init__(self, pad_token_id=2, hidden_size=3072):
+ self.pad_token_id = pad_token_id
+ self.hidden_size = hidden_size
+
+ def create_position(self, attention_mask, num_tokens_for_output_images):
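+        # Padding tokens (left padding) are assigned position 0; the real text tokens, the single time token, and
+        # the output-image tokens then receive consecutive positions starting from 0.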
+ position_ids = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ temp_position = [0] * (text_length - temp_l) + list(
+ range(temp_l + img_length + 1)
+ ) # we add a time embedding into the sequence, so add one more token
+ position_ids.append(temp_position)
+ return torch.LongTensor(position_ids)
+
+ def create_mask(self, attention_mask, num_tokens_for_output_images):
+ """
+ OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within
+        each image sequence. References: [OmniGen](https://arxiv.org/pdf/2409.11340)
+ """
+ extended_mask = []
+ padding_images = []
+ text_length = attention_mask.size(-1)
+ img_length = max(num_tokens_for_output_images)
+ seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
+ inx = 0
+ for mask in attention_mask:
+ temp_l = torch.sum(mask)
+ pad_l = text_length - temp_l
+
+ temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
+
+ image_mask = torch.zeros(size=(temp_l + 1, img_length))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
+
+ image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
+ temp_mask = torch.cat([temp_mask, image_mask], dim=0)
+
+ if pad_l > 0:
+ pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
+
+ pad_mask = torch.ones(size=(pad_l, seq_len))
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
+
+ true_img_length = num_tokens_for_output_images[inx]
+ pad_img_length = img_length - true_img_length
+ if pad_img_length > 0:
+ temp_mask[:, -pad_img_length:] = 0
+ temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
+ else:
+ temp_padding_imgs = None
+
+ extended_mask.append(temp_mask.unsqueeze(0))
+ padding_images.append(temp_padding_imgs)
+ inx += 1
+ return torch.cat(extended_mask, dim=0), padding_images
+
+ def adjust_attention_for_input_images(self, attention_mask, image_sizes):
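+        # Tokens belonging to the same input image attend to each other bidirectionally: unmask the square block
+        # covering each image's [start, end) token span.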
+ for b_inx in image_sizes.keys():
+ for start_inx, end_inx in image_sizes[b_inx]:
+ attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
+
+ return attention_mask
+
+ def pad_input_ids(self, input_ids, image_sizes):
+ max_l = max([len(x) for x in input_ids])
+ padded_ids = []
+ attention_mask = []
+
+ for i in range(len(input_ids)):
+ temp_ids = input_ids[i]
+ temp_l = len(temp_ids)
+ pad_l = max_l - temp_l
+ if pad_l == 0:
+ attention_mask.append([1] * max_l)
+ padded_ids.append(temp_ids)
+ else:
+ attention_mask.append([0] * pad_l + [1] * temp_l)
+ padded_ids.append([self.pad_token_id] * pad_l + temp_ids)
+
+ if i in image_sizes:
+ new_inx = []
+ for old_inx in image_sizes[i]:
+ new_inx.append([x + pad_l for x in old_inx])
+ image_sizes[i] = new_inx
+
+ return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
+
+ def process_mllm_input(self, mllm_inputs, target_img_size):
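+        # Each output image occupies (height / 16) * (width / 16) tokens: a factor of 8 from the VAE downsampling
+        # and a factor of 2 from the 2x2 patching of the latents.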
+ num_tokens_for_output_images = []
+ for img_size in target_img_size:
+ num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)
+
+ pixel_values, image_sizes = [], {}
+ b_inx = 0
+ for x in mllm_inputs:
+ if x["pixel_values"] is not None:
+ pixel_values.extend(x["pixel_values"])
+ for size in x["image_sizes"]:
+ if b_inx not in image_sizes:
+ image_sizes[b_inx] = [size]
+ else:
+ image_sizes[b_inx].append(size)
+ b_inx += 1
+ pixel_values = [x.unsqueeze(0) for x in pixel_values]
+
+ input_ids = [x["input_ids"] for x in mllm_inputs]
+ padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
+ position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
+ attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
+ attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
+
+ return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
+
+ def __call__(self, features):
+ mllm_inputs = [f[0] for f in features]
+ cfg_mllm_inputs = [f[1] for f in features]
+ img_cfg_mllm_input = [f[2] for f in features]
+ target_img_size = [f[3] for f in features]
+
+ if img_cfg_mllm_input[0] is not None:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
+ target_img_size = target_img_size + target_img_size + target_img_size
+ else:
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs
+ target_img_size = target_img_size + target_img_size
+
+ (
+ all_padded_input_ids,
+ all_position_ids,
+ all_attention_mask,
+ all_padding_images,
+ all_pixel_values,
+ all_image_sizes,
+ ) = self.process_mllm_input(mllm_inputs, target_img_size)
+
+ data = {
+ "input_ids": all_padded_input_ids,
+ "attention_mask": all_attention_mask,
+ "position_ids": all_position_ids,
+ "input_pixel_values": all_pixel_values,
+ "input_image_sizes": all_image_sizes,
+ }
+ return data
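
To make the sizing contract of `crop_image` concrete (both output dimensions are multiples of 16 and no larger than `max_image_size`), here is a small self-contained sketch; the 1365x900 input size is arbitrary, and the import assumes the module path added in this diff.

```py
from PIL import Image

from diffusers.pipelines.omnigen.processor_omnigen import crop_image

img = Image.new("RGB", (1365, 900))          # stand-in for a real input photo
out = crop_image(img, max_image_size=1024)

# The longer side is scaled down to 1024 and the shorter side is center-cropped to a multiple of 16.
print(out.size)                              # (1024, 672)
assert max(out.size) <= 1024
assert out.size[0] % 16 == 0 and out.size[1] % 16 == 0
```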
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 6a1978944c9f..671ab63c9ef3 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -621,6 +621,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class OmniGenTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index b899915c3046..29ebd554223c 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1217,6 +1217,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class OmniGenPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class PaintByExamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/models/transformers/test_models_transformer_omnigen.py b/tests/models/transformers/test_models_transformer_omnigen.py
new file mode 100644
index 000000000000..a7653f1f9d6d
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_omnigen.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import OmniGenTransformer2DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = OmniGenTransformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ height = 8
+ width = 8
+ sequence_length = 24
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ timestep = torch.rand(size=(batch_size,), dtype=hidden_states.dtype).to(torch_device)
+ input_ids = torch.randint(0, 10, (batch_size, sequence_length)).to(torch_device)
+ input_img_latents = [torch.randn((1, num_channels, height, width)).to(torch_device)]
+ input_image_sizes = {0: [[0, 0 + height * width // 2 // 2]]}
+
+ attn_seq_length = sequence_length + 1 + height * width // 2 // 2
+ attention_mask = torch.ones((batch_size, attn_seq_length, attn_seq_length)).to(torch_device)
+ position_ids = torch.LongTensor([list(range(attn_seq_length))] * batch_size).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "input_ids": input_ids,
+ "input_img_latents": input_img_latents,
+ "input_image_sizes": input_image_sizes,
+ "attention_mask": attention_mask,
+ "position_ids": position_ids,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 8, 8)
+
+ @property
+ def output_shape(self):
+ return (4, 8, 8)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "hidden_size": 16,
+ "num_attention_heads": 4,
+ "num_key_value_heads": 4,
+ "intermediate_size": 32,
+ "num_layers": 1,
+ "pad_token_id": 0,
+ "vocab_size": 100,
+ "in_channels": 4,
+ "time_step_dim": 4,
+ "rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"OmniGenTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
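
The `attn_seq_length` in the dummy inputs above follows directly from how the transformer lays out its sequence (text tokens, one time token, then the packed 2x2 latent patches); a quick sanity check of that arithmetic:

```py
sequence_length = 24                          # dummy text tokens
height = width = 8                            # dummy latent height/width
image_tokens = height * width // 2 // 2       # 2x2 patching -> 16 image tokens
attn_seq_length = sequence_length + 1 + image_tokens  # +1 for the time token
assert attn_seq_length == 41
```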
diff --git a/tests/pipelines/omnigen/__init__.py b/tests/pipelines/omnigen/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py
new file mode 100644
index 000000000000..dd5e5fcb2918
--- /dev/null
+++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py
@@ -0,0 +1,153 @@
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
+from diffusers.utils.testing_utils import (
+ numpy_cosine_similarity_distance,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = OmniGenPipeline
+ params = frozenset(
+ [
+ "prompt",
+ "guidance_scale",
+ ]
+ )
+ batch_params = frozenset(
+ [
+ "prompt",
+ ]
+ )
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+
+ transformer = OmniGenTransformer2DModel(
+ hidden_size=16,
+ num_attention_heads=4,
+ num_key_value_heads=4,
+ intermediate_size=32,
+ num_layers=1,
+ in_channels=4,
+ time_step_dim=4,
+ rope_scaling={"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4, 4, 4, 4),
+ layers_per_block=1,
+ latent_channels=4,
+ norm_num_groups=1,
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 1,
+ "guidance_scale": 3.0,
+ "output_type": "np",
+ "height": 16,
+ "width": 16,
+ }
+ return inputs
+
+ def test_inference(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ generated_image = pipe(**inputs).images[0]
+
+ self.assertEqual(generated_image.shape, (16, 16, 3))
+
+
+@slow
+@require_torch_gpu
+class OmniGenPipelineSlowTests(unittest.TestCase):
+ pipeline_class = OmniGenPipeline
+ repo_id = "shitao/OmniGen-v1-diffusers"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ return {
+ "prompt": "A photo of a cat",
+ "num_inference_steps": 2,
+ "guidance_scale": 2.5,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ def test_omnigen_inference(self):
+ pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload()
+
+ inputs = self.get_inputs(torch_device)
+
+ image = pipe(**inputs).images[0]
+ image_slice = image[0, :10, :10]
+
+ expected_slice = np.array(
+ [
+ [0.1783447, 0.16772744, 0.14339337],
+ [0.17066911, 0.15521264, 0.13757327],
+ [0.17072496, 0.15531206, 0.13524258],
+ [0.16746324, 0.1564025, 0.13794944],
+ [0.16490817, 0.15258026, 0.13697758],
+ [0.16971767, 0.15826806, 0.13928896],
+ [0.16782972, 0.15547255, 0.13783783],
+ [0.16464645, 0.15281534, 0.13522372],
+ [0.16535294, 0.15301755, 0.13526791],
+ [0.16365296, 0.15092957, 0.13443318],
+ ],
+ dtype=np.float32,
+ )
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+ assert max_diff < 1e-4
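
The slow test compares slices by cosine distance rather than element-wise error; assuming `numpy_cosine_similarity_distance` computes one minus the cosine similarity of the flattened arrays, an equivalent standalone check would look roughly like this:

```py
import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of the flattened arrays; assumed to mirror the
    # `numpy_cosine_similarity_distance` helper used in the test above.
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# e.g. assert cosine_distance(expected_slice, image_slice) < 1e-4
```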