Skip to content

Wan VACE #11582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open

Wan VACE #11582

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion docs/source/en/api/pipelines/wan.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@

[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.

<!-- TODO(aryan): update abstract once paper is out -->
*This report presents Wan, a comprehensive and open suite of video foundation models designed to push the boundaries of video generation. Built upon the mainstream diffusion transformer paradigm, Wan achieves significant advancements in generative capabilities through a series of innovations, including our novel VAE, scalable pre-training strategies, large-scale data curation, and automated evaluation metrics. These contributions collectively enhance the model's performance and versatility. Specifically, Wan is characterized by four key features: Leading Performance: The 14B model of Wan, trained on a vast dataset comprising billions of images and videos, demonstrates the scaling laws of video generation with respect to both data and model size. It consistently outperforms the existing open-source models as well as state-of-the-art commercial solutions across multiple internal and external benchmarks, demonstrating a clear and significant performance superiority. Comprehensiveness: Wan offers two capable models, i.e., 1.3B and 14B parameters, for efficiency and effectiveness respectively. It also covers multiple downstream applications, including image-to-video, instruction-guided video editing, and personal video generation, encompassing up to eight tasks. Consumer-Grade Efficiency: The 1.3B model demonstrates exceptional resource efficiency, requiring only 8.19 GB VRAM, making it compatible with a wide range of consumer-grade GPUs. Openness: We open-source the entire series of Wan, including source code and all models, with the goal of fostering the growth of the video generation community. This openness seeks to significantly expand the creative possibilities of video production in the industry and provide academia with high-quality video foundation models. All the code and models are available at [this https URL](https://github.com/Wan-Video/Wan2.1).*

The following Wan models are supported in Diffusers:
- [Wan 2.1 T2V 1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
- [Wan 2.1 T2V 14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
- [Wan 2.1 I2V 14B - 480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
- [Wan 2.1 I2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)
- [Wan 2.1 FLF2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
- [Wan 2.1 VACE 1.3B](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) (Unofficial diffusers checkpoint for now available [here](https://huggingface.co/a-r-r-o-w/Wan-VACE-1.3B-diffusers))
- [Wan 2.1 VACE 14B](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) (Unofficial diffusers checkpoint for now available [here](https://huggingface.co/linoyts/Wan-VACE-14B-diffusers))

## Generating Videos with Wan 2.1

Expand Down Expand Up @@ -227,6 +236,19 @@ output = pipe(
export_to_video(output, "wan-v2v.mp4", fps=16)
```

### Any-to-Video Controllable Generation

Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include:
- Control to Video (Depth, Pose, Sketch, Flow, Grayscale, Scribble, Layout, Boundary Box, etc.). Recommended library for preprocessing videos to obtain control videos: [huggingface/controlnet_aux]()
- Image/Video to Video (first frame, last frame, starting clip, ending clip, random clips)
- Inpainting and Outpainting
- Subject to Video (faces, object, characters, etc.)
- Composition to Video (reference anything, animate anything, swap anything, expand anything, move anything, etc.)

The code snippets available in [this](https://github.com/huggingface/diffusers/pull/11582) pull request demonstrate some examples of how videos can be generated with controllability signals.

The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.

## Memory Optimizations for Wan 2.1

Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
Expand Down
143 changes: 130 additions & 13 deletions scripts/convert_wan_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
import pathlib
from typing import Any, Dict
from typing import Any, Dict, Tuple

import torch
from accelerate import init_empty_weights
Expand All @@ -14,6 +14,8 @@
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
WanVACEPipeline,
WanVACETransformer3DModel,
)


Expand Down Expand Up @@ -59,7 +61,52 @@
"attn2.norm_k_img": "attn2.norm_added_k",
}

VACE_TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
# # For the I2V model
# "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
# "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
# "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
# "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# # for the FLF2V model
# "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
# Add attention component mappings
"self_attn.q": "attn1.to_q",
"self_attn.k": "attn1.to_k",
"self_attn.v": "attn1.to_v",
"self_attn.o": "attn1.to_out.0",
"self_attn.norm_q": "attn1.norm_q",
"self_attn.norm_k": "attn1.norm_k",
"cross_attn.q": "attn2.to_q",
"cross_attn.k": "attn2.to_k",
"cross_attn.v": "attn2.to_v",
"cross_attn.o": "attn2.to_out.0",
"cross_attn.norm_q": "attn2.norm_q",
"cross_attn.norm_k": "attn2.norm_k",
"attn2.to_k_img": "attn2.add_k_proj",
"attn2.to_v_img": "attn2.add_v_proj",
"attn2.norm_k_img": "attn2.norm_added_k",
"before_proj": "proj_in",
"after_proj": "proj_out",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {}
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
Expand All @@ -74,7 +121,7 @@ def load_sharded_safetensors(dir: pathlib.Path):
return state_dict


def get_transformer_config(model_type: str) -> Dict[str, Any]:
def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
if model_type == "Wan-T2V-1.3B":
config = {
"model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
Expand All @@ -94,6 +141,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-T2V-14B":
config = {
"model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
Expand All @@ -113,6 +162,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-I2V-14B-480p":
config = {
"model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
Expand All @@ -133,6 +184,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-I2V-14B-720p":
config = {
"model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
Expand All @@ -153,6 +206,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-FLF2V-14B-720P":
config = {
"model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
Expand All @@ -175,28 +230,80 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"pos_embed_seq_len": 257 * 2,
},
}
return config
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-VACE-1.3B":
config = {
"model_id": "Wan-AI/Wan2.1-VACE-1.3B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 12,
"num_layers": 30,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
"vace_in_channels": 96,
},
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-VACE-14B":
config = {
"model_id": "Wan-AI/Wan2.1-VACE-14B",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
"vace_in_channels": 96,
},
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP


def convert_transformer(model_type: str):
config = get_transformer_config(model_type)
config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)

diffusers_config = config["diffusers_config"]
model_id = config["model_id"]
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))

original_state_dict = load_sharded_safetensors(model_dir)

with init_empty_weights():
transformer = WanTransformer3DModel.from_config(diffusers_config)
if "VACE" not in model_type:
transformer = WanTransformer3DModel.from_config(diffusers_config)
else:
transformer = WanVACETransformer3DModel.from_config(diffusers_config)

for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
for replace_key, rename_key in RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)

for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
Expand Down Expand Up @@ -412,7 +519,7 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--dtype", default="fp32")
parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
return parser.parse_args()


Expand All @@ -426,18 +533,20 @@ def get_args():
if __name__ == "__main__":
args = get_args()

transformer = None
dtype = DTYPE_MAPPING[args.dtype]

transformer = convert_transformer(args.model_type).to(dtype=dtype)
transformer = convert_transformer(args.model_type)
vae = convert_vae()
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)

# If user has specified "none", we keep the original dtypes of the state dict without any conversion
if args.dtype != "none":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)

if "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
Expand All @@ -452,6 +561,14 @@ def get_args():
image_encoder=image_encoder,
image_processor=image_processor,
)
elif "VACE" in args.model_type:
pipe = WanVACEPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
)
else:
pipe = WanPipeline(
transformer=transformer,
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
"UVit2DModel",
"VQModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
]
)
_import_structure["optimization"] = [
Expand Down Expand Up @@ -527,6 +528,7 @@
"VQDiffusionPipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVACEPipeline",
"WanVideoToVideoPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
Expand Down Expand Up @@ -820,6 +822,7 @@
UVit2DModel,
VQModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
from .optimization import (
get_constant_schedule,
Expand Down Expand Up @@ -1111,6 +1114,7 @@
VQDiffusionPipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVACEPipeline,
WanVideoToVideoPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
}


Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
Expand Down Expand Up @@ -178,6 +179,7 @@
Transformer2DModel,
TransformerTemporalModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
)
from .unets import (
I2VGenXLUNet,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
Loading