Skip to content

Commit d90cd36

Browse files
zRzRzRzRzRzRzROleehyOa-r-r-o-wyiyixuxu
authored
CogView4 (supports different length c and uc) (#10649)
* init * encode with glm * draft schedule * feat(scheduler): Add CogView scheduler implementation * feat(embeddings): add CogView 2D rotary positional embedding * 1 * Update pipeline_cogview4.py * fix the timestep init and sigma * update latent * draft patch(not work) * fix * [WIP][cogview4]: implement initial CogView4 pipeline Implement the basic CogView4 pipeline structure with the following changes: - Add CogView4 pipeline implementation - Implement DDIM scheduler for CogView4 - Add CogView3Plus transformer architecture - Update embedding models Current limitations: - CFG implementation uses padding for sequence length alignment - Need to verify transformer inference alignment with Megatron TODO: - Consider separate forward passes for condition/uncondition instead of padding approach * [WIP][cogview4][refactor]: Split condition/uncondition forward pass in CogView4 pipeline Split the forward pass for conditional and unconditional predictions in the CogView4 pipeline to match the original implementation. The noise prediction is now done separately for each case before combining them for guidance. However, the results still need improvement. This is a work in progress as the generated images are not yet matching expected quality. * use with -2 hidden state * remove text_projector * 1 * [WIP] Add tensor-reload to align input from transformer block * [WIP] for older glm * use with cogview4 transformers forward twice of u and uc * Update convert_cogview4_to_diffusers.py * remove this * use main example * change back * reset * setback * back * back 4 * Fix qkv conversion logic for CogView4 to Diffusers format * back5 * revert to sat to cogview4 version * update a new convert from megatron * [WIP][cogview4]: implement CogView4 attention processor Add CogView4AttnProcessor class for implementing scaled dot-product attention with rotary embeddings for the CogVideoX model. This processor concatenates encoder and hidden states, applies QKV projections and RoPE, but does not include spatial normalization. TODO: - Fix incorrect QKV projection weights - Resolve ~25% error in RoPE implementation compared to Megatron * [cogview4] implement CogView4 transformer block Implement CogView4 transformer block following the Megatron architecture: - Add multi-modulate and multi-gate mechanisms for adaptive layer normalization - Implement dual-stream attention with encoder-decoder structure - Add feed-forward network with GELU activation - Support rotary position embeddings for image tokens The implementation follows the original CogView4 architecture while adapting it to work within the diffusers framework. * with new attn * [bugfix] fix dimension mismatch in CogView4 attention * [cogview4][WIP]: update final normalization in CogView4 transformer Refactored the final normalization layer in CogView4 transformer to use separate layernorm and AdaLN operations instead of combined AdaLayerNormContinuous. This matches the original implementation but needs validation. Needs verification against reference implementation. * 1 * put back * Update transformer_cogview4.py * change time_shift * Update pipeline_cogview4.py * change timesteps * fix * change text_encoder_id * [cogview4][rope] align RoPE implementation with Megatron - Implement apply_rope method in attention processor to match Megatron's implementation - Update position embeddings to ensure compatibility with Megatron-style rotary embeddings - Ensure consistent rotary position encoding across attention layers This change improves compatibility with Megatron-based models and provides better alignment with the original implementation's positional encoding approach. * [cogview4][bugfix] apply silu activation to time embeddings in CogView4 Applied silu activation to time embeddings before splitting into conditional and unconditional parts in CogView4Transformer2DModel. This matches the original implementation and helps ensure correct time conditioning behavior. * [cogview4][chore] clean up pipeline code - Remove commented out code and debug statements - Remove unused retrieve_timesteps function - Clean up code formatting and documentation This commit focuses on code cleanup in the CogView4 pipeline implementation, removing unnecessary commented code and improving readability without changing functionality. * [cogview4][scheduler] Implement CogView4 scheduler and pipeline * now It work * add timestep * batch * change convert scipt * refactor pt. 1; make style * refactor pt. 2 * refactor pt. 3 * add tests * make fix-copies * update toctree.yml * use flow match scheduler instead of custom * remove scheduling_cogview.py * add tiktoken to test dependencies * Update src/diffusers/models/embeddings.py Co-authored-by: YiYi Xu <[email protected]> * apply suggestions from review * use diffusers apply_rotary_emb * update flow match scheduler to accept timesteps * fix comment * apply review sugestions * Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py Co-authored-by: YiYi Xu <[email protected]> --------- Co-authored-by: 三洋三洋 <[email protected]> Co-authored-by: OleehyO <[email protected]> Co-authored-by: Aryan <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent 69f919d commit d90cd36

24 files changed

+2262
-18
lines changed

docs/source/en/_toctree.yml

+4
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,8 @@
278278
title: ConsisIDTransformer3DModel
279279
- local: api/models/cogview3plus_transformer2d
280280
title: CogView3PlusTransformer2DModel
281+
- local: api/models/cogview4_transformer2d
282+
title: CogView4Transformer2DModel
281283
- local: api/models/dit_transformer2d
282284
title: DiTTransformer2DModel
283285
- local: api/models/flux_transformer
@@ -382,6 +384,8 @@
382384
title: CogVideoX
383385
- local: api/pipelines/cogview3
384386
title: CogView3
387+
- local: api/pipelines/cogview4
388+
title: CogView4
385389
- local: api/pipelines/consisid
386390
title: ConsisID
387391
- local: api/pipelines/consistency_models
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License. -->
11+
12+
# CogView4Transformer2DModel
13+
14+
A Diffusion Transformer model for 2D data from [CogView4]()
15+
16+
The model can be loaded with the following code snippet.
17+
18+
```python
19+
from diffusers import CogView4Transformer2DModel
20+
21+
transformer = CogView4Transformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
22+
```
23+
24+
## CogView4Transformer2DModel
25+
26+
[[autodoc]] CogView4Transformer2DModel
27+
28+
## Transformer2DModelOutput
29+
30+
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
-->
15+
16+
# CogView4
17+
18+
<Tip>
19+
20+
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
21+
22+
</Tip>
23+
24+
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
25+
26+
## CogView4Pipeline
27+
28+
[[autodoc]] CogView4Pipeline
29+
- all
30+
- __call__
31+
32+
## CogView4PipelineOutput
33+
34+
[[autodoc]] pipelines.cogview4.pipeline_output.CogView4PipelineOutput
+243
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
"""
2+
Convert a CogView4 checkpoint from SAT(https://github.com/THUDM/SwissArmyTransformer) to the Diffusers format.
3+
(deprecated Since 2025-02-07 and will remove it in later CogView4 version)
4+
5+
This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
6+
with the Diffusers library.
7+
8+
Example usage:
9+
python scripts/convert_cogview4_to_diffusers.py \
10+
--transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
11+
--vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
12+
--output_path "THUDM/CogView4-6B" \
13+
--dtype "bf16"
14+
15+
Arguments:
16+
--transformer_checkpoint_path: Path to Transformer state dict.
17+
--vae_checkpoint_path: Path to VAE state dict.
18+
--output_path: The path to save the converted model.
19+
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
20+
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
21+
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
22+
23+
Default is "bf16" because CogView4 uses bfloat16 for Training.
24+
25+
Note: You must provide either --original_state_dict_repo_id or --checkpoint_path.
26+
"""
27+
28+
import argparse
29+
from contextlib import nullcontext
30+
31+
import torch
32+
from accelerate import init_empty_weights
33+
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
34+
35+
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
36+
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
37+
from diffusers.utils.import_utils import is_accelerate_available
38+
39+
40+
CTX = init_empty_weights if is_accelerate_available() else nullcontext
41+
42+
parser = argparse.ArgumentParser()
43+
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
44+
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
45+
parser.add_argument("--output_path", required=True, type=str)
46+
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
47+
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
48+
parser.add_argument("--dtype", type=str, default="bf16")
49+
50+
args = parser.parse_args()
51+
52+
53+
# this is specific to `AdaLayerNormContinuous`:
54+
# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
55+
def swap_scale_shift(weight, dim):
56+
shift, scale = weight.chunk(2, dim=0)
57+
new_weight = torch.cat([scale, shift], dim=0)
58+
return new_weight
59+
60+
61+
def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
62+
original_state_dict = torch.load(ckpt_path, map_location="cpu")
63+
original_state_dict = original_state_dict["module"]
64+
original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
65+
66+
new_state_dict = {}
67+
68+
# Convert patch_embed
69+
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
70+
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
71+
new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
72+
new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
73+
74+
# Convert time_condition_embed
75+
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
76+
"time_embed.0.weight"
77+
)
78+
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
79+
"time_embed.0.bias"
80+
)
81+
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
82+
"time_embed.2.weight"
83+
)
84+
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
85+
"time_embed.2.bias"
86+
)
87+
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
88+
"label_emb.0.0.weight"
89+
)
90+
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
91+
"label_emb.0.0.bias"
92+
)
93+
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
94+
"label_emb.0.2.weight"
95+
)
96+
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
97+
"label_emb.0.2.bias"
98+
)
99+
100+
# Convert transformer blocks, for cogview4 is 28 blocks
101+
for i in range(28):
102+
block_prefix = f"transformer_blocks.{i}."
103+
old_prefix = f"transformer.layers.{i}."
104+
adaln_prefix = f"mixins.adaln.adaln_modules.{i}."
105+
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
106+
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")
107+
108+
qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
109+
qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
110+
q, k, v = qkv_weight.chunk(3, dim=0)
111+
q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)
112+
113+
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
114+
new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
115+
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
116+
new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
117+
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
118+
new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias
119+
120+
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
121+
old_prefix + "attention.dense.weight"
122+
)
123+
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
124+
old_prefix + "attention.dense.bias"
125+
)
126+
127+
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
128+
old_prefix + "mlp.dense_h_to_4h.weight"
129+
)
130+
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
131+
old_prefix + "mlp.dense_h_to_4h.bias"
132+
)
133+
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
134+
old_prefix + "mlp.dense_4h_to_h.weight"
135+
)
136+
new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")
137+
138+
# Convert final norm and projection
139+
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
140+
original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
141+
)
142+
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
143+
original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
144+
)
145+
new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
146+
new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")
147+
148+
return new_state_dict
149+
150+
151+
def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
152+
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
153+
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
154+
155+
156+
def main(args):
157+
if args.dtype == "fp16":
158+
dtype = torch.float16
159+
elif args.dtype == "bf16":
160+
dtype = torch.bfloat16
161+
elif args.dtype == "fp32":
162+
dtype = torch.float32
163+
else:
164+
raise ValueError(f"Unsupported dtype: {args.dtype}")
165+
166+
transformer = None
167+
vae = None
168+
169+
if args.transformer_checkpoint_path is not None:
170+
converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers(
171+
args.transformer_checkpoint_path
172+
)
173+
transformer = CogView4Transformer2DModel(
174+
patch_size=2,
175+
in_channels=16,
176+
num_layers=28,
177+
attention_head_dim=128,
178+
num_attention_heads=32,
179+
out_channels=16,
180+
text_embed_dim=4096,
181+
time_embed_dim=512,
182+
condition_dim=256,
183+
pos_embed_max_size=128,
184+
)
185+
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
186+
if dtype is not None:
187+
# Original checkpoint data type will be preserved
188+
transformer = transformer.to(dtype=dtype)
189+
190+
if args.vae_checkpoint_path is not None:
191+
vae_config = {
192+
"in_channels": 3,
193+
"out_channels": 3,
194+
"down_block_types": ("DownEncoderBlock2D",) * 4,
195+
"up_block_types": ("UpDecoderBlock2D",) * 4,
196+
"block_out_channels": (128, 512, 1024, 1024),
197+
"layers_per_block": 3,
198+
"act_fn": "silu",
199+
"latent_channels": 16,
200+
"norm_num_groups": 32,
201+
"sample_size": 1024,
202+
"scaling_factor": 1.0,
203+
"force_upcast": True,
204+
"use_quant_conv": False,
205+
"use_post_quant_conv": False,
206+
"mid_block_add_attention": False,
207+
}
208+
converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
209+
vae = AutoencoderKL(**vae_config)
210+
vae.load_state_dict(converted_vae_state_dict, strict=True)
211+
if dtype is not None:
212+
vae = vae.to(dtype=dtype)
213+
214+
text_encoder_id = "THUDM/glm-4-9b-hf"
215+
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
216+
text_encoder = GlmForCausalLM.from_pretrained(
217+
text_encoder_id,
218+
cache_dir=args.text_encoder_cache_dir,
219+
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
220+
)
221+
222+
for param in text_encoder.parameters():
223+
param.data = param.data.contiguous()
224+
225+
scheduler = FlowMatchEulerDiscreteScheduler(
226+
base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear"
227+
)
228+
229+
pipe = CogView4Pipeline(
230+
tokenizer=tokenizer,
231+
text_encoder=text_encoder,
232+
vae=vae,
233+
transformer=transformer,
234+
scheduler=scheduler,
235+
)
236+
237+
# This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
238+
# save some memory used for model loading.
239+
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
240+
241+
242+
if __name__ == "__main__":
243+
main(args)

0 commit comments

Comments
 (0)