Cosmos #10660 (Open)

a-r-r-o-w wants to merge 48 commits into main from integrations/cosmos
Commits (48)
65decb6 begin transformer conversion (a-r-r-o-w, Jan 27, 2025)
ed4527f Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Feb 1, 2025)
a282f47 refactor (a-r-r-o-w, Feb 2, 2025)
2753089 refactor (a-r-r-o-w, Feb 2, 2025)
b23ac33 refactor (a-r-r-o-w, Feb 2, 2025)
62f6369 refactor (a-r-r-o-w, Feb 3, 2025)
3d2c5ee refactor (a-r-r-o-w, Feb 3, 2025)
6eb43df refactor (a-r-r-o-w, Feb 3, 2025)
969dd17 update (a-r-r-o-w, Feb 3, 2025)
88faab1 add conversion script (a-r-r-o-w, Feb 3, 2025)
a63e543 add pipeline (a-r-r-o-w, Feb 3, 2025)
4f1161d make fix-copies (a-r-r-o-w, Feb 3, 2025)
e4173df remove einops (a-r-r-o-w, Feb 3, 2025)
6d6c10c update docs (a-r-r-o-w, Feb 3, 2025)
c5bd5a3 gradient checkpointing (a-r-r-o-w, Feb 3, 2025)
f9fc67c add transformer test (a-r-r-o-w, Feb 3, 2025)
89906c2 update (a-r-r-o-w, Feb 5, 2025)
98f1ce7 debug (a-r-r-o-w, Feb 5, 2025)
9a7f479 remove prints (a-r-r-o-w, Feb 5, 2025)
475ad31 Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Feb 18, 2025)
9df2e7e match sigmas (a-r-r-o-w, Feb 18, 2025)
cedcab1 add vae pt. 1 (a-r-r-o-w, Feb 25, 2025)
2dda910 finish CV* vae (a-r-r-o-w, Feb 25, 2025)
de925be update (a-r-r-o-w, Feb 25, 2025)
59d7793 update (a-r-r-o-w, Feb 26, 2025)
1203f44 Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Mar 10, 2025)
b9a5255 update (a-r-r-o-w, Mar 10, 2025)
75f3f45 update (a-r-r-o-w, Mar 10, 2025)
10289f7 update (a-r-r-o-w, Mar 10, 2025)
547d68f update (a-r-r-o-w, Mar 11, 2025)
13cd8cd make fix-copies (a-r-r-o-w, Mar 11, 2025)
6f8495b Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Mar 11, 2025)
15c8020 Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Mar 11, 2025)
9ee31fb update (a-r-r-o-w, Mar 11, 2025)
7c54eb1 make fix-copies (a-r-r-o-w, Mar 11, 2025)
bf9190f fix (a-r-r-o-w, Mar 12, 2025)
64fc4fe update (a-r-r-o-w, Mar 12, 2025)
a592f74 update (a-r-r-o-w, Mar 12, 2025)
22ea3ca make fix-copies (a-r-r-o-w, Mar 12, 2025)
e897d0c Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Mar 20, 2025)
cd712f0 update (a-r-r-o-w, Mar 21, 2025)
8c188ec update tests (a-r-r-o-w, Mar 21, 2025)
ebea597 Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Mar 21, 2025)
7799728 Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Mar 26, 2025)
2c2b658 handle device and dtype for safety checker; required in latest diffusers (a-r-r-o-w, Mar 26, 2025)
0c3f56f remove enable_gqa and use repeat_interleave instead (a-r-r-o-w, Apr 5, 2025)
b909f7e Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Apr 5, 2025)
06373f1 Merge branch 'main' into integrations/cosmos (a-r-r-o-w, Apr 10, 2025)
Files changed
docs/source/en/_toctree.yml (4 additions, 0 deletions)

@@ -290,6 +290,8 @@
       title: CogView3PlusTransformer2DModel
     - local: api/models/cogview4_transformer2d
       title: CogView4Transformer2DModel
+    - local: api/models/cosmos_transformer3d
+      title: CosmosTransformer3DModel
     - local: api/models/dit_transformer2d
       title: DiTTransformer2DModel
     - local: api/models/easyanimate_transformer3d
@@ -424,6 +426,8 @@
       title: ControlNet-XS with Stable Diffusion XL
     - local: api/pipelines/controlnet_union
       title: ControlNetUnion
+    - local: api/pipelines/cosmos
+      title: Cosmos
     - local: api/pipelines/dance_diffusion
       title: Dance Diffusion
     - local: api/pipelines/ddim
docs/source/en/api/models/cosmos_transformer3d.md (new file, 30 additions)
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# CosmosTransformer3DModel

A Diffusion Transformer model for 3D video-like data was introduced in [Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import CosmosTransformer3DModel

transformer = CosmosTransformer3DModel.from_pretrained("nvidia/Cosmos-1.0-Diffusion-7B-Text2World", subfolder="transformer", torch_dtype=torch.bfloat16)
```

## CosmosTransformer3DModel

[[autodoc]] CosmosTransformer3DModel

## Transformer2DModelOutput

[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
docs/source/en/api/pipelines/cosmos.md (new file, 35 additions)
@@ -0,0 +1,35 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# Cosmos

[Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.

*Physical AI needs to be trained digitally first. It needs a digital twin of itself, the policy model, and a digital twin of the world, the world model. In this paper, we present the Cosmos World Foundation Model Platform to help developers build customized world models for their Physical AI setups. We position a world foundation model as a general-purpose world model that can be fine-tuned into customized world models for downstream applications. Our platform covers a video curation pipeline, pre-trained world foundation models, examples of post-training of pre-trained world foundation models, and video tokenizers. To help Physical AI builders solve the most critical problems of our society, we make our platform open-source and our models open-weight with permissive licenses available via https://github.com/NVIDIA/Cosmos.*

<Tip>

Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
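
The example below is a minimal text-to-world sketch. It assumes a converted checkpoint is available at the hub id used in the model docs (`nvidia/Cosmos-1.0-Diffusion-7B-Text2World`) and that the pipeline call follows the usual diffusers text-to-video convention; the exact arguments and output attribute names may differ in the final API.

```python
import torch

from diffusers import CosmosPipeline
from diffusers.utils import export_to_video

pipe = CosmosPipeline.from_pretrained(
    "nvidia/Cosmos-1.0-Diffusion-7B-Text2World", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

prompt = "A robot arm carefully stacks wooden blocks on a workbench."
video = pipe(prompt=prompt, num_inference_steps=35).frames[0]
export_to_video(video, "output.mp4", fps=24)
```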

## CosmosPipeline

[[autodoc]] CosmosPipeline
- all
- __call__

## CosmosPipelineOutput

[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
scripts/convert_cosmos_to_diffusers.py (new file, 287 additions)
@@ -0,0 +1,287 @@
import argparse
import pathlib
from typing import Any, Dict

import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast

from diffusers import AutoencoderKLCosmos, CosmosPipeline, CosmosTransformer3DModel, EDMEulerScheduler


def remove_keys_(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)


def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
    block_index = int(key.split(".")[1].removeprefix("block"))
    new_key = key

    old_prefix = f"blocks.block{block_index}"
    new_prefix = f"transformer_blocks.{block_index}"
    new_key = new_prefix + new_key.removeprefix(old_prefix)

    state_dict[new_key] = state_dict.pop(key)


TRANSFORMER_KEYS_RENAME_DICT = {
    "t_embedder.1": "time_embed.t_embedder",
    "affline_norm": "time_embed.norm",
    ".blocks.0.block.attn": ".attn1",
    ".blocks.1.block.attn": ".attn2",
    ".blocks.2.block": ".ff",
    ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
    ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
    ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
    ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
    ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
    ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
    "to_q.0": "to_q",
    "to_q.1": "norm_q",
    "to_k.0": "to_k",
    "to_k.1": "norm_k",
    "to_v.0": "to_v",
    "layer1": "net.0.proj",
    "layer2": "net.2",
    "proj.1": "proj",
    "x_embedder": "patch_embed",
    "extra_pos_embedder": "learnable_pos_embed",
    "final_layer.adaLN_modulation.1": "norm_out.linear_1",
    "final_layer.adaLN_modulation.2": "norm_out.linear_2",
    "final_layer.linear": "proj_out",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "blocks.block": rename_transformer_blocks_,
    "logvar.0.freqs": remove_keys_,
    "logvar.0.phases": remove_keys_,
    "logvar.1.weight": remove_keys_,
    "pos_embedder.seq": remove_keys_,
}
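
# Illustration of how the two passes compose (the example key is assumed, for
# illustration only): after the "net." prefix is stripped, a key such as
#   "blocks.block3.blocks.0.block.attn.to_q.0.weight"
# is first rewritten by the substring renames above to
#   "blocks.block3.attn1.to_q.weight"
# and then rename_transformer_blocks_ maps the "blocks.block3" prefix, giving
#   "transformer_blocks.3.attn1.to_q.weight"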

VAE_KEYS_RENAME_DICT = {
    "down.0": "down_blocks.0",
    "down.1": "down_blocks.1",
    "down.2": "down_blocks.2",
    "up.0": "up_blocks.2",
    "up.1": "up_blocks.1",
    "up.2": "up_blocks.0",
    ".block.": ".resnets.",
    "downsample": "downsamplers.0",
    "upsample": "upsamplers.0",
    "mid.block_1": "mid_block.resnets.0",
    "mid.attn_1.0": "mid_block.attentions.0",
    "mid.attn_1.1": "mid_block.temp_attentions.0",
    "mid.block_2": "mid_block.resnets.1",
    ".q.conv3d": ".to_q",
    ".k.conv3d": ".to_k",
    ".v.conv3d": ".to_v",
    ".proj_out.conv3d": ".to_out.0",
    ".0.conv3d": ".conv_s",
    ".1.conv3d": ".conv_t",
    "conv1.conv3d": "conv1",
    "conv2.conv3d": "conv2",
    "conv3.conv3d": "conv3",
    "nin_shortcut.conv3d": "conv_shortcut",
    "quant_conv.conv3d": "quant_conv",
    "post_quant_conv.conv3d": "post_quant_conv",
}

VAE_SPECIAL_KEYS_REMAP = {
    "wavelets": remove_keys_,
    "_arange": remove_keys_,
    "patch_size_buffer": remove_keys_,
}
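
# These entries are precomputed, non-learnable buffers (presumably the Haar
# wavelet filters and index buffers behind patch_type="haar"); diffusers
# recreates them at module init, so they can be dropped from the checkpoint.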

VAE_CONFIGS = {
    "CV8x8x8-0.1": {
        "name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
        "diffusers_config": {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 16,
            "encoder_block_out_channels": (128, 256, 512, 512),
            "decode_block_out_channels": (256, 512, 512, 512),
            "attention_resolutions": (32,),
            "resolution": 1024,
            "num_layers": 2,
            "patch_size": 4,
            "patch_type": "haar",
            "scaling_factor": 1.0,
            "spatial_compression_ratio": 8,
            "temporal_compression_ratio": 8,
            "latents_mean": None,
            "latents_std": None,
        },
    },
    "CV8x8x8-1.0": {
        "name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
        "diffusers_config": {
            "in_channels": 3,
            "out_channels": 3,
            "latent_channels": 16,
            "encoder_block_out_channels": (128, 256, 512, 512),
            "decode_block_out_channels": (256, 512, 512, 512),
            "attention_resolutions": (32,),
            "resolution": 1024,
            "num_layers": 2,
            "patch_size": 4,
            "patch_type": "haar",
            "scaling_factor": 1.0,
            "spatial_compression_ratio": 8,
            "temporal_compression_ratio": 8,
            "latents_mean": None,
            "latents_std": None,
        },
    },
}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
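    # Unwrap the nested checkpoint containers ("model", "module", "state_dict")
    # that the original Cosmos training code may save the raw tensors under.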
    state_dict = saved_dict
    if "model" in saved_dict.keys():
        state_dict = state_dict["model"]
    if "module" in saved_dict.keys():
        state_dict = state_dict["module"]
    if "state_dict" in saved_dict.keys():
        state_dict = state_dict["state_dict"]
    return state_dict


def convert_transformer(ckpt_path: str):
    PREFIX_KEY = "net."
    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))

    with init_empty_weights():
        transformer = CosmosTransformer3DModel()

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        if new_key.startswith(PREFIX_KEY):
            new_key = new_key.removeprefix(PREFIX_KEY)
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

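    # assign=True is needed because the model was created under init_empty_weights():
    # its parameters are meta tensors, so the checkpoint tensors must be assigned
    # in place of being copied into.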
    transformer.load_state_dict(original_state_dict, strict=True, assign=True)
    return transformer


def convert_vae(vae_type: str):
    model_name = VAE_CONFIGS[vae_type]["name"]
    snapshot_directory = snapshot_download(model_name, repo_type="model")
    directory = pathlib.Path(snapshot_directory)

    autoencoder_file = directory / "autoencoder.jit"
    mean_std_file = directory / "mean_std.pt"

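    # The original tokenizer ships as a TorchScript archive (autoencoder.jit),
    # so the raw tensors are pulled out of the scripted module's state_dict.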
    original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
    if mean_std_file.exists():
        mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
    else:
        mean_std = (None, None)

    config = VAE_CONFIGS[vae_type]["diffusers_config"]
    if mean_std[0] is not None:
        # Only override the (None, None) defaults when the checkpoint actually
        # provides latent statistics; calling .detach() on None would fail.
        config.update(
            {
                "latents_mean": mean_std[0].detach().cpu().numpy().tolist(),
                "latents_std": mean_std[1].detach().cpu().numpy().tolist(),
            }
        )
    vae = AutoencoderKLCosmos(**config)

    for key in list(original_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    vae.load_state_dict(original_state_dict, strict=True, assign=True)
    return vae


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
    parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
    parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
    parser.add_argument("--save_pipeline", action="store_true")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
    return parser.parse_args()


DTYPE_MAPPING = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


if __name__ == "__main__":
args = get_args()

transformer = None
dtype = DTYPE_MAPPING[args.dtype]

if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None

if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.vae_type is not None:
vae = convert_vae(args.vae_type)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")

if args.save_pipeline:
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
        # The original code initializes the EDM config with sigma_min=0.0002 but never
        # uses it directly, so the effective sigma_min is the default value of 0.002.
        scheduler = EDMEulerScheduler(
            sigma_min=0.002,
            sigma_max=80,
            sigma_data=0.5,
            sigma_schedule="karras",
            num_train_timesteps=1000,
            prediction_type="epsilon",
            rho=7.0,
            final_sigmas_type="sigma_min",
        )
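        # For reference, the "karras" schedule spaces the noise levels as
        #   sigma_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho))) ** rho
        # so sigma_min/sigma_max bound the range and rho=7.0 biases steps toward
        # low noise (Karras et al., 2022).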

        pipe = CosmosPipeline(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
        )
        pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
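
# Example invocation (a sketch; the checkpoint path is a placeholder):
#   python scripts/convert_cosmos_to_diffusers.py \
#       --transformer_ckpt_path /path/to/cosmos_transformer.pt \
#       --vae_type CV8x8x8-1.0 \
#       --save_pipeline \
#       --output_path ./cosmos-diffusers \
#       --dtype bf16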