diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index aab3d4d130df..7a1088f63521 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -278,6 +278,8 @@
       title: ConsisIDTransformer3DModel
     - local: api/models/cogview3plus_transformer2d
       title: CogView3PlusTransformer2DModel
+    - local: api/models/cogview4_transformer2d
+      title: CogView4Transformer2DModel
     - local: api/models/dit_transformer2d
       title: DiTTransformer2DModel
     - local: api/models/flux_transformer
@@ -382,6 +384,8 @@
       title: CogVideoX
     - local: api/pipelines/cogview3
       title: CogView3
+    - local: api/pipelines/cogview4
+      title: CogView4
     - local: api/pipelines/consisid
       title: ConsisID
     - local: api/pipelines/consistency_models
diff --git a/docs/source/en/api/models/cogview4_transformer2d.md b/docs/source/en/api/models/cogview4_transformer2d.md
new file mode 100644
index 000000000000..4bf14bdd4991
--- /dev/null
+++ b/docs/source/en/api/models/cogview4_transformer2d.md
@@ -0,0 +1,30 @@
+
+# CogView4Transformer2DModel
+
+A Diffusion Transformer model for 2D data from [CogView4]().
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import CogView4Transformer2DModel
+
+transformer = CogView4Transformer2DModel.from_pretrained("THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## CogView4Transformer2DModel
+
+[[autodoc]] CogView4Transformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/cogview4.md b/docs/source/en/api/pipelines/cogview4.md
new file mode 100644
index 000000000000..cc17c3c905fb
--- /dev/null
+++ b/docs/source/en/api/pipelines/cogview4.md
@@ -0,0 +1,34 @@
+
+# CogView4
+
+<Tip>
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+</Tip>
+
+This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
+
+## CogView4Pipeline
+
+[[autodoc]] CogView4Pipeline
+  - all
+  - __call__
+
+## CogView4PipelineOutput
+
+[[autodoc]] pipelines.cogview4.pipeline_output.CogView4PipelineOutput
diff --git a/scripts/convert_cogview4_to_diffusers.py b/scripts/convert_cogview4_to_diffusers.py
new file mode 100644
index 000000000000..484c817dd938
--- /dev/null
+++ b/scripts/convert_cogview4_to_diffusers.py
@@ -0,0 +1,243 @@
+"""
+Convert a CogView4 checkpoint from SAT (https://github.com/THUDM/SwissArmyTransformer) to the Diffusers format.
+(Deprecated since 2025-02-07; this script will be removed in a later CogView4 version.)
+
+This script converts a CogView4 checkpoint to the Diffusers format, which can then be used
+with the Diffusers library.
+
+Example usage:
+    python scripts/convert_cogview4_to_diffusers.py \
+        --transformer_checkpoint_path 'your path/cogview4_6b/1/mp_rank_00_model_states.pt' \
+        --vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \
+        --output_path "THUDM/CogView4-6B" \
+        --dtype "bf16"
+
+Arguments:
+    --transformer_checkpoint_path: Path to Transformer state dict.
+    --vae_checkpoint_path: Path to VAE state dict.
+    --output_path: The path to save the converted model.
+    --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
+    --text_encoder_cache_dir: Cache directory where the text encoder is located. Defaults to None, which means HF_HOME will be used.
+    --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered.
+
+        Default is "bf16" because CogView4 uses bfloat16 for training.
+
+Note: You must provide either --transformer_checkpoint_path or --vae_checkpoint_path.
+"""
+
+import argparse
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from transformers import GlmForCausalLM, PreTrainedTokenizerFast
+
+from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
+parser.add_argument("--vae_checkpoint_path", default=None, type=str)
+parser.add_argument("--output_path", required=True, type=str)
+parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
+parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
+parser.add_argument("--dtype", type=str, default="bf16")
+
+args = parser.parse_args()
+
+
+# This is specific to `AdaLayerNormContinuous`:
+# the diffusers implementation splits the linear projection into (scale, shift), while CogView4 splits it into (shift, scale)
+def swap_scale_shift(weight, dim):
+    shift, scale = weight.chunk(2, dim=0)
+    new_weight = torch.cat([scale, shift], dim=0)
+    return new_weight
+
+
+def convert_cogview4_transformer_checkpoint_to_diffusers(ckpt_path):
+    original_state_dict = torch.load(ckpt_path, map_location="cpu")
+    original_state_dict = original_state_dict["module"]
+    original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}
+
+    new_state_dict = {}
+
+    # Convert patch_embed
+    new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
+    new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
+    new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
+    new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")
+
+    # Convert time_condition_embed
+    new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+        "time_embed.0.weight"
+    )
+    new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+        "time_embed.0.bias"
+    )
+    new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+        "time_embed.2.weight"
+    )
+    new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+        "time_embed.2.bias"
+    )
+    new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
+        "label_emb.0.0.weight"
+    )
+    new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
+ "label_emb.0.0.bias" + ) + new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop( + "label_emb.0.2.weight" + ) + new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop( + "label_emb.0.2.bias" + ) + + # Convert transformer blocks, for cogview4 is 28 blocks + for i in range(28): + block_prefix = f"transformer_blocks.{i}." + old_prefix = f"transformer.layers.{i}." + adaln_prefix = f"mixins.adaln.adaln_modules.{i}." + new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight") + new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias") + + qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight") + qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias") + q, k, v = qkv_weight.chunk(3, dim=0) + q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias + + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop( + old_prefix + "attention.dense.weight" + ) + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop( + old_prefix + "attention.dense.bias" + ) + + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop( + old_prefix + "mlp.dense_h_to_4h.weight" + ) + new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop( + old_prefix + "mlp.dense_h_to_4h.bias" + ) + new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop( + old_prefix + "mlp.dense_4h_to_h.weight" + ) + new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias") + + # Convert final norm and projection + new_state_dict["norm_out.linear.weight"] = swap_scale_shift( + original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0 + ) + new_state_dict["norm_out.linear.bias"] = swap_scale_shift( + original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0 + ) + new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight") + new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias") + + return new_state_dict + + +def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config): + original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + return convert_ldm_vae_checkpoint(original_state_dict, vae_config) + + +def main(args): + if args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + elif args.dtype == "fp32": + dtype = torch.float32 + else: + raise ValueError(f"Unsupported dtype: {args.dtype}") + + transformer = None + vae = None + + if args.transformer_checkpoint_path is not None: + converted_transformer_state_dict = convert_cogview4_transformer_checkpoint_to_diffusers( + args.transformer_checkpoint_path + ) + transformer = CogView4Transformer2DModel( + patch_size=2, + in_channels=16, + num_layers=28, + attention_head_dim=128, + num_attention_heads=32, + out_channels=16, + text_embed_dim=4096, + time_embed_dim=512, + condition_dim=256, + 
pos_embed_max_size=128, + ) + transformer.load_state_dict(converted_transformer_state_dict, strict=True) + if dtype is not None: + # Original checkpoint data type will be preserved + transformer = transformer.to(dtype=dtype) + + if args.vae_checkpoint_path is not None: + vae_config = { + "in_channels": 3, + "out_channels": 3, + "down_block_types": ("DownEncoderBlock2D",) * 4, + "up_block_types": ("UpDecoderBlock2D",) * 4, + "block_out_channels": (128, 512, 1024, 1024), + "layers_per_block": 3, + "act_fn": "silu", + "latent_channels": 16, + "norm_num_groups": 32, + "sample_size": 1024, + "scaling_factor": 1.0, + "force_upcast": True, + "use_quant_conv": False, + "use_post_quant_conv": False, + "mid_block_add_attention": False, + } + converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config) + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_state_dict, strict=True) + if dtype is not None: + vae = vae.to(dtype=dtype) + + text_encoder_id = "THUDM/glm-4-9b-hf" + tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id) + text_encoder = GlmForCausalLM.from_pretrained( + text_encoder_id, + cache_dir=args.text_encoder_cache_dir, + torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32, + ) + + for param in text_encoder.parameters(): + param.data = param.data.contiguous() + + scheduler = FlowMatchEulerDiscreteScheduler( + base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear" + ) + + pipe = CogView4Pipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + + # This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can + # save some memory used for model loading. + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub) + + +if __name__ == "__main__": + main(args) diff --git a/scripts/convert_cogview4_to_diffusers_megatron.py b/scripts/convert_cogview4_to_diffusers_megatron.py new file mode 100644 index 000000000000..de5354952493 --- /dev/null +++ b/scripts/convert_cogview4_to_diffusers_megatron.py @@ -0,0 +1,366 @@ +""" +Convert a CogView4 checkpoint from Megatron to the Diffusers format. + +Example usage: + python scripts/convert_cogview4_to_diffusers.py \ + --transformer_checkpoint_path 'your path/cogview4_6b/mp_rank_00/model_optim_rng.pt' \ + --vae_checkpoint_path 'your path/cogview4_6b/imagekl_ch16.pt' \ + --output_path "THUDM/CogView4-6B" \ + --dtype "bf16" + +Arguments: + --transformer_checkpoint_path: Path to Transformer state dict. + --vae_checkpoint_path: Path to VAE state dict. + --output_path: The path to save the converted model. + --push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`. + --text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used. + --dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). If None, the dtype of the state dict is considered. + + Default is "bf16" because CogView4 uses bfloat16 for training. + +Note: You must provide either --transformer_checkpoint_path or --vae_checkpoint_path. 
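+
+After conversion, the saved pipeline can be loaded back with Diffusers. A minimal sketch (replace the
+path below with whatever was passed to --output_path):
+    import torch
+    from diffusers import CogView4Pipeline
+
+    pipe = CogView4Pipeline.from_pretrained("path/to/output", torch_dtype=torch.bfloat16).to("cuda")
+    image = pipe("A photo of an astronaut riding a horse on mars").images[0]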
+""" + +import argparse + +import torch +from tqdm import tqdm +from transformers import GlmForCausalLM, PreTrainedTokenizerFast + +from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler +from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--transformer_checkpoint_path", + default=None, + type=str, + help="Path to Megatron (not SAT) Transformer checkpoint, e.g., 'model_optim_rng.pt'.", +) +parser.add_argument( + "--vae_checkpoint_path", + default=None, + type=str, + help="(Optional) Path to VAE checkpoint, e.g., 'imagekl_ch16.pt'.", +) +parser.add_argument( + "--output_path", + required=True, + type=str, + help="Directory to save the final Diffusers format pipeline.", +) +parser.add_argument( + "--push_to_hub", + action="store_true", + default=False, + help="Whether to push the converted model to the HuggingFace Hub.", +) +parser.add_argument( + "--text_encoder_cache_dir", + type=str, + default=None, + help="Specify the cache directory for the text encoder.", +) +parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16", "fp32"], + help="Data type to save the model in.", +) + +parser.add_argument( + "--num_layers", + type=int, + default=28, + help="Number of Transformer layers (e.g., 28, 48...).", +) +parser.add_argument( + "--num_heads", + type=int, + default=32, + help="Number of attention heads.", +) +parser.add_argument( + "--hidden_size", + type=int, + default=4096, + help="Transformer hidden dimension size.", +) +parser.add_argument( + "--attention_head_dim", + type=int, + default=128, + help="Dimension of each attention head.", +) +parser.add_argument( + "--time_embed_dim", + type=int, + default=512, + help="Dimension of time embeddings.", +) +parser.add_argument( + "--condition_dim", + type=int, + default=256, + help="Dimension of condition embeddings.", +) +parser.add_argument( + "--pos_embed_max_size", + type=int, + default=128, + help="Maximum size for positional embeddings.", +) + +args = parser.parse_args() + + +def swap_scale_shift(weight, dim): + """ + Swap the scale and shift components in the weight tensor. + + Args: + weight (torch.Tensor): The original weight tensor. + dim (int): The dimension along which to split. + + Returns: + torch.Tensor: The modified weight tensor with scale and shift swapped. + """ + shift, scale = weight.chunk(2, dim=dim) + new_weight = torch.cat([scale, shift], dim=dim) + return new_weight + + +def convert_megatron_transformer_checkpoint_to_diffusers( + ckpt_path: str, + num_layers: int, + num_heads: int, + hidden_size: int, +): + """ + Convert a Megatron Transformer checkpoint to Diffusers format. + + Args: + ckpt_path (str): Path to the Megatron Transformer checkpoint. + num_layers (int): Number of Transformer layers. + num_heads (int): Number of attention heads. + hidden_size (int): Hidden size of the Transformer. + + Returns: + dict: The converted state dictionary compatible with Diffusers. 
+ """ + ckpt = torch.load(ckpt_path, map_location="cpu") + mega = ckpt["model"] + + new_state_dict = {} + + # Patch Embedding + new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64) + new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"] + new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"] + new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"] + + # Time Condition Embedding + new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = mega[ + "time_embedding.time_embed.0.weight" + ] + new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = mega["time_embedding.time_embed.0.bias"] + new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = mega[ + "time_embedding.time_embed.2.weight" + ] + new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = mega["time_embedding.time_embed.2.bias"] + + new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = mega[ + "label_embedding.label_embed.0.weight" + ] + new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = mega[ + "label_embedding.label_embed.0.bias" + ] + new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = mega[ + "label_embedding.label_embed.2.weight" + ] + new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = mega[ + "label_embedding.label_embed.2.bias" + ] + + # Convert each Transformer layer + for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"): + block_prefix = f"transformer_blocks.{i}." + + # AdaLayerNorm + new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift( + mega[f"decoder.layers.{i}.adaln.weight"], dim=0 + ) + new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift( + mega[f"decoder.layers.{i}.adaln.bias"], dim=0 + ) + + # QKV + qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"] + qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"] + + # Reshape to match SAT logic + qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size) + qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size) + + qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads) + qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size) + + # Assign to Diffusers keys + q, k, v = torch.chunk(qkv_weight, 3, dim=0) + qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0) + + new_state_dict[block_prefix + "attn1.to_q.weight"] = q + new_state_dict[block_prefix + "attn1.to_q.bias"] = qb + new_state_dict[block_prefix + "attn1.to_k.weight"] = k + new_state_dict[block_prefix + "attn1.to_k.bias"] = kb + new_state_dict[block_prefix + "attn1.to_v.weight"] = v + new_state_dict[block_prefix + "attn1.to_v.bias"] = vb + + # Attention Output + new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[ + f"decoder.layers.{i}.self_attention.linear_proj.weight" + ].T + new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[ + f"decoder.layers.{i}.self_attention.linear_proj.bias" + ] + + # MLP + new_state_dict[block_prefix + "ff.net.0.proj.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.weight"] + new_state_dict[block_prefix + "ff.net.0.proj.bias"] = mega[f"decoder.layers.{i}.mlp.linear_fc1.bias"] + new_state_dict[block_prefix + "ff.net.2.weight"] = mega[f"decoder.layers.{i}.mlp.linear_fc2.weight"] + new_state_dict[block_prefix + "ff.net.2.bias"] = 
mega[f"decoder.layers.{i}.mlp.linear_fc2.bias"] + + # Final Layers + new_state_dict["norm_out.linear.weight"] = swap_scale_shift(mega["adaln_final.weight"], dim=0) + new_state_dict["norm_out.linear.bias"] = swap_scale_shift(mega["adaln_final.bias"], dim=0) + new_state_dict["proj_out.weight"] = mega["output_projector.weight"] + new_state_dict["proj_out.bias"] = mega["output_projector.bias"] + + return new_state_dict + + +def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config): + """ + Convert a CogView4 VAE checkpoint to Diffusers format. + + Args: + ckpt_path (str): Path to the VAE checkpoint. + vae_config (dict): Configuration dictionary for the VAE. + + Returns: + dict: The converted VAE state dictionary compatible with Diffusers. + """ + original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + return convert_ldm_vae_checkpoint(original_state_dict, vae_config) + + +def main(args): + """ + Main function to convert CogView4 checkpoints to Diffusers format. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + """ + # Determine the desired data type + if args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + elif args.dtype == "fp32": + dtype = torch.float32 + else: + raise ValueError(f"Unsupported dtype: {args.dtype}") + + transformer = None + vae = None + + # Convert Transformer checkpoint if provided + if args.transformer_checkpoint_path is not None: + converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers( + ckpt_path=args.transformer_checkpoint_path, + num_layers=args.num_layers, + num_heads=args.num_heads, + hidden_size=args.hidden_size, + ) + transformer = CogView4Transformer2DModel( + patch_size=2, + in_channels=16, + num_layers=args.num_layers, + attention_head_dim=args.attention_head_dim, + num_attention_heads=args.num_heads, + out_channels=16, + text_embed_dim=args.hidden_size, + time_embed_dim=args.time_embed_dim, + condition_dim=args.condition_dim, + pos_embed_max_size=args.pos_embed_max_size, + ) + + transformer.load_state_dict(converted_transformer_state_dict, strict=True) + + # Convert to the specified dtype + if dtype is not None: + transformer = transformer.to(dtype=dtype) + + # Convert VAE checkpoint if provided + if args.vae_checkpoint_path is not None: + vae_config = { + "in_channels": 3, + "out_channels": 3, + "down_block_types": ("DownEncoderBlock2D",) * 4, + "up_block_types": ("UpDecoderBlock2D",) * 4, + "block_out_channels": (128, 512, 1024, 1024), + "layers_per_block": 3, + "act_fn": "silu", + "latent_channels": 16, + "norm_num_groups": 32, + "sample_size": 1024, + "scaling_factor": 1.0, + "force_upcast": True, + "use_quant_conv": False, + "use_post_quant_conv": False, + "mid_block_add_attention": False, + } + converted_vae_state_dict = convert_cogview4_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config) + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_state_dict, strict=True) + if dtype is not None: + vae = vae.to(dtype=dtype) + + # Load the text encoder and tokenizer + text_encoder_id = "THUDM/glm-4-9b-hf" + tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id) + text_encoder = GlmForCausalLM.from_pretrained( + text_encoder_id, + cache_dir=args.text_encoder_cache_dir, + torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32, + ) + for param in text_encoder.parameters(): + param.data = param.data.contiguous() + + # Initialize the scheduler + scheduler = 
FlowMatchEulerDiscreteScheduler( + base_shift=0.25, max_shift=0.75, base_image_seq_len=256, use_dynamic_shifting=True, time_shift_type="linear" + ) + + # Create the pipeline + pipe = CogView4Pipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + ) + + # Save the converted pipeline + pipe.save_pretrained( + args.output_path, + safe_serialization=True, + max_shard_size="5GB", + push_to_hub=args.push_to_hub, + ) + + +if __name__ == "__main__": + main(args) diff --git a/setup.py b/setup.py index 0acdcbbb9c52..1da12e158b36 100644 --- a/setup.py +++ b/setup.py @@ -130,6 +130,7 @@ "regex!=2019.12.17", "requests", "tensorboard", + "tiktoken>=0.7.0", "torch>=1.4", "torchvision", "transformers>=4.41.2", @@ -226,6 +227,7 @@ def run(self): "safetensors", "sentencepiece", "scipy", + "tiktoken", "torchvision", "transformers", "phonemizer", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5d1c2f13b8e0..a9e7c823db41 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -101,6 +101,7 @@ "CacheMixin", "CogVideoXTransformer3DModel", "CogView3PlusTransformer2DModel", + "CogView4Transformer2DModel", "ConsisIDTransformer3DModel", "ConsistencyDecoderVAE", "ControlNetModel", @@ -287,6 +288,7 @@ "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", "CogView3PlusPipeline", + "CogView4Pipeline", "ConsisIDPipeline", "CycleDiffusionPipeline", "FluxControlImg2ImgPipeline", @@ -619,6 +621,7 @@ CacheMixin, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, + CogView4Transformer2DModel, ConsisIDTransformer3DModel, ConsistencyDecoderVAE, ControlNetModel, @@ -784,6 +787,7 @@ CogVideoXPipeline, CogVideoXVideoToVideoPipeline, CogView3PlusPipeline, + CogView4Pipeline, ConsisIDPipeline, CycleDiffusionPipeline, FluxControlImg2ImgPipeline, diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 7999368f1417..17d5da60347d 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -38,6 +38,7 @@ "regex": "regex!=2019.12.17", "requests": "requests", "tensorboard": "tensorboard", + "tiktoken": "tiktoken>=0.7.0", "torch": "torch>=1.4", "torchvision": "torchvision", "transformers": "transformers>=4.41.2", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 661f4ca6307a..853f149fe01c 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -70,6 +70,7 @@ _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] + _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] @@ -136,6 +137,7 @@ AuraFlowTransformer2DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, + CogView4Transformer2DModel, ConsisIDTransformer3DModel, DiTTransformer2DModel, DualTransformer2DModel, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5d873baf8fbb..8bba5a82bc2f 100644 --- a/src/diffusers/models/attention_processor.py +++ 
b/src/diffusers/models/attention_processor.py @@ -2825,9 +2825,7 @@ def __call__( hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + batch_size, sequence_length, _ = hidden_states.shape if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c42fbbc9f0a3..390b752abe15 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1199,7 +1199,7 @@ def apply_rotary_emb( x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) elif use_real_unbind_dim == -2: - # Used for Stable Audio and OmniGen + # Used for Stable Audio, OmniGen and CogView4 x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index f16f605a6cd7..f32c30ceff3c 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -18,6 +18,7 @@ from .transformer_2d import Transformer2DModel from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel + from .transformer_cogview4 import CogView4Transformer2DModel from .transformer_flux import FluxTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py new file mode 100644 index 000000000000..f622791b572f --- /dev/null +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -0,0 +1,420 @@ +# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
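+
+# A minimal loading sketch for this model (the hub repository and subfolder below are the ones referenced in
+# the docs and conversion scripts; adjust them if the weights live elsewhere):
+#
+#     import torch
+#     from diffusers import CogView4Transformer2DModel
+#
+#     transformer = CogView4Transformer2DModel.from_pretrained(
+#         "THUDM/CogView4-6B", subfolder="transformer", torch_dtype=torch.bfloat16
+#     )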
+ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention import FeedForward +from ...models.attention_processor import Attention +from ...models.modeling_utils import ModelMixin +from ...models.normalization import AdaLayerNormContinuous +from ...utils import logging +from ..embeddings import CogView3CombinedTimestepSizeEmbeddings +from ..modeling_outputs import Transformer2DModelOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CogView4PatchEmbed(nn.Module): + def __init__( + self, + in_channels: int = 16, + hidden_size: int = 2560, + patch_size: int = 2, + text_hidden_size: int = 4096, + ): + super().__init__() + self.patch_size = patch_size + + self.proj = nn.Linear(in_channels * patch_size**2, hidden_size) + self.text_proj = nn.Linear(text_hidden_size, hidden_size) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, channel, height, width = hidden_states.shape + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + + hidden_states = hidden_states.reshape( + batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size + ) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + hidden_states = self.proj(hidden_states) + encoder_hidden_states = self.text_proj(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class CogView4AdaLayerNormZero(nn.Module): + def __init__(self, embedding_dim: int, dim: int) -> None: + super().__init__() + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) + + def forward( + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states = self.norm(hidden_states) + norm_encoder_hidden_states = self.norm_context(encoder_hidden_states) + + emb = self.linear(temb) + ( + shift_msa, + c_shift_msa, + scale_msa, + c_scale_msa, + gate_msa, + c_gate_msa, + shift_mlp, + c_shift_mlp, + scale_mlp, + c_scale_mlp, + gate_mlp, + c_gate_mlp, + ) = emb.chunk(12, dim=1) + + hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1) + + return ( + hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) + + +class CogView4AttnProcessor: + """ + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + query[:, :, text_seq_length:, :] = apply_rotary_emb( + query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + key[:, :, text_seq_length:, :] = apply_rotary_emb( + key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 + ) + + # 4. Attention + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # 5. Output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +class CogView4TransformerBlock(nn.Module): + def __init__( + self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512 + ) -> None: + super().__init__() + + # 1. Attention + self.norm1 = CogView4AdaLayerNormZero(time_embed_dim, dim) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + out_dim=dim, + bias=True, + qk_norm="layer_norm", + elementwise_affine=False, + eps=1e-5, + processor=CogView4AttnProcessor(), + ) + + # 2. Feedforward + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # 1. Timestep conditioning + ( + norm_hidden_states, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1(hidden_states, encoder_hidden_states, temb) + + # 2. Attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) + + # 3. 
Feedforward + norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * ( + 1 + c_scale_mlp.unsqueeze(1) + ) + c_shift_mlp.unsqueeze(1) + + ff_output = self.ff(norm_hidden_states) + ff_output_context = self.ff(norm_encoder_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) + + return hidden_states, encoder_hidden_states + + +class CogView4RotaryPosEmbed(nn.Module): + def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.rope_axes_dim = rope_axes_dim + + dim_h, dim_w = dim // 2, dim // 2 + h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)) + w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)) + h_seq = torch.arange(self.rope_axes_dim[0]) + w_seq = torch.arange(self.rope_axes_dim[1]) + self.freqs_h = torch.outer(h_seq, h_inv_freq) + self.freqs_w = torch.outer(w_seq, w_inv_freq) + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, height, width = hidden_states.shape + height, width = height // self.patch_size, width // self.patch_size + + h_idx = torch.arange(height) + w_idx = torch.arange(width) + inner_h_idx = h_idx * self.rope_axes_dim[0] // height + inner_w_idx = w_idx * self.rope_axes_dim[1] // width + + self.freqs_h = self.freqs_h.to(hidden_states.device) + self.freqs_w = self.freqs_w.to(hidden_states.device) + freqs_h = self.freqs_h[inner_h_idx] + freqs_w = self.freqs_w[inner_w_idx] + + # Create position matrices for height and width + # [height, 1, dim//4] and [1, width, dim//4] + freqs_h = freqs_h.unsqueeze(1) + freqs_w = freqs_w.unsqueeze(0) + # Broadcast freqs_h and freqs_w to [height, width, dim//4] + freqs_h = freqs_h.expand(height, width, -1) + freqs_w = freqs_w.expand(height, width, -1) + + # Concatenate along last dimension to get [height, width, dim//2] + freqs = torch.cat([freqs_h, freqs_w], dim=-1) + freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim] + freqs = freqs.reshape(height * width, -1) + return (freqs.cos(), freqs.sin()) + + +class CogView4Transformer2DModel(ModelMixin, ConfigMixin): + r""" + Args: + patch_size (`int`, defaults to `2`): + The size of the patches to use in the patch embedding layer. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. + attention_head_dim (`int`, defaults to `40`): + The number of channels in each head. + num_attention_heads (`int`, defaults to `64`): + The number of heads to use for multi-head attention. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + condition_dim (`int`, defaults to `256`): + The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size, + crop_coords). 
+ pos_embed_max_size (`int`, defaults to `128`): + The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added + to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128 + means that the maximum supported height and width for image generation is `128 * vae_scale_factor * + patch_size => 128 * 8 * 2 => 2048`. + sample_size (`int`, defaults to `128`): + The base resolution of input latents. If height/width is not provided during generation, this value is used + to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024` + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["CogView4TransformerBlock", "CogView4PatchEmbed", "CogView4PatchEmbed"] + _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"] + + @register_to_config + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: int = 16, + num_layers: int = 30, + attention_head_dim: int = 40, + num_attention_heads: int = 64, + text_embed_dim: int = 4096, + time_embed_dim: int = 512, + condition_dim: int = 256, + pos_embed_max_size: int = 128, + sample_size: int = 128, + rope_axes_dim: Tuple[int, int] = (256, 256), + ): + super().__init__() + + # CogView4 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords + # Each of these are sincos embeddings of shape 2 * condition_dim + pooled_projection_dim = 3 * 2 * condition_dim + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels + + # 1. RoPE + self.rope = CogView4RotaryPosEmbed(attention_head_dim, patch_size, rope_axes_dim, theta=10000.0) + + # 2. Patch & Text-timestep embedding + self.patch_embed = CogView4PatchEmbed(in_channels, inner_dim, patch_size, text_embed_dim) + + self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings( + embedding_dim=time_embed_dim, + condition_dim=condition_dim, + pooled_projection_dim=pooled_projection_dim, + timesteps_dim=inner_dim, + ) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + CogView4TransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + for _ in range(num_layers) + ] + ) + + # 4. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + original_size: torch.Tensor, + target_size: torch.Tensor, + crop_coords: torch.Tensor, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + batch_size, num_channels, height, width = hidden_states.shape + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. Patch & Timestep embeddings + p = self.config.patch_size + post_patch_height = height // p + post_patch_width = width // p + + hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states) + + temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype) + temb = F.silu(temb) + + # 3. 
Transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states, encoder_hidden_states, temb, image_rotary_emb + ) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p) + output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 84e193f681d6..49041086f535 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -154,6 +154,7 @@ "CogVideoXFunControlPipeline", ] _import_structure["cogview3"] = ["CogView3PlusPipeline"] + _import_structure["cogview4"] = ["CogView4Pipeline"] _import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["controlnet"].extend( [ @@ -499,6 +500,7 @@ CogVideoXVideoToVideoPipeline, ) from .cogview3 import CogView3PlusPipeline + from .cogview4 import CogView4Pipeline from .consisid import ConsisIDPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 6066836e7a05..1c38f83a7ef3 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -22,6 +22,7 @@ from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline from .cogview3 import CogView3PlusPipeline +from .cogview4 import CogView4Pipeline from .controlnet import ( StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, @@ -138,6 +139,7 @@ ("lumina", LuminaText2ImgPipeline), ("lumina2", Lumina2Text2ImgPipeline), ("cogview3", CogView3PlusPipeline), + ("cogview4", CogView4Pipeline), ] ) diff --git a/src/diffusers/pipelines/cogview4/__init__.py b/src/diffusers/pipelines/cogview4/__init__.py new file mode 100644 index 000000000000..5a535b3feb4b --- /dev/null +++ b/src/diffusers/pipelines/cogview4/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["CogView4PlusPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"] +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_cogview4 import CogView4Pipeline +else: + import sys + + 
sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py new file mode 100644 index 000000000000..097d1b6aed41 --- /dev/null +++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py @@ -0,0 +1,665 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import AutoTokenizer, GlmModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL, CogView4Transformer2DModel +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from .pipeline_output import CogView4PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogView4Pipeline + + >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image.save("output.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + base_shift: float = 0.25, + max_shift: float = 0.75, +) -> float: + m = (image_seq_len / base_seq_len) ** 0.5 + mu = m * max_shift + base_shift + return mu + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+    accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+
+    if timesteps is not None and sigmas is not None:
+        if not accepts_timesteps and not accepts_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif timesteps is not None and sigmas is None:
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif timesteps is None and sigmas is not None:
+        if not accepts_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class CogView4Pipeline(DiffusionPipeline):
+    r"""
+    Pipeline for text-to-image generation using CogView4.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`GlmModel`]):
+            Frozen text-encoder. CogView4 uses [GLM-4](https://huggingface.co/THUDM/glm-4-9b-hf) as its text encoder.
+        tokenizer (`PreTrainedTokenizerFast`):
+            Tokenizer paired with the GLM-4 text encoder.
+        transformer ([`CogView4Transformer2DModel`]):
+            A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
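+
+    Examples:
+        A minimal text-to-image sketch (it mirrors the module-level example docstring):
+
+        ```python
+        >>> import torch
+        >>> from diffusers import CogView4Pipeline
+
+        >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
+        >>> image = pipe("A photo of an astronaut riding a horse on mars").images[0]
+        >>> image.save("output.png")
+        ```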
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: GlmModel, + vae: AutoencoderKL, + transformer: CogView4Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _get_glm_embeds( + self, + prompt: Union[str, List[str]] = None, + max_sequence_length: int = 1024, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer( + prompt, + padding="longest", # not use max length + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + current_length = text_input_ids.shape[1] + pad_length = (16 - (current_length % 16)) % 16 + if pad_length > 0: + pad_ids = torch.full( + (text_input_ids.shape[0], pad_length), + fill_value=self.tokenizer.pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1) + prompt_embeds = self.text_encoder( + text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True + ).hidden_states[-2] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 1024, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + max_sequence_length (`int`, defaults to `1024`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_glm_embeds(prompt, max_sequence_length, device, dtype) + + seq_len = prompt_embeds.size(1) + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_glm_embeds(negative_prompt, max_sequence_length, device, dtype) + + seq_len = negative_prompt_embeds.size(1) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if latents is not None: + return latents.to(device) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. 
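+ # `do_classifier_free_guidance` below is True only when `guidance_scale > 1`; it gates the extra unconditional forward pass through the transformer in the denoising loop.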
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 1024, + ) -> Union[CogView4PipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. If not provided, it is set to 1024. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. If not provided it is set to 1024. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to `1`): + The number of images to generate per prompt. 
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.cogview4.pipeline_output.CogView4PipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `1024`): + Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
+ + Examples: + + Returns: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`: + [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = (height, width) + + # Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + self.do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Prepare latents + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Prepare additional timestep conditions + original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device) + target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device) + crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device) + + original_size = original_size.repeat(batch_size * num_images_per_prompt, 1) + target_size = target_size.repeat(batch_size * num_images_per_prompt, 1) + crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1) + + # Prepare timesteps + image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // ( + self.transformer.config.patch_size**2 + ) + timesteps = ( + np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps) + if timesteps is None + else np.array(timesteps) + ) + timesteps = timesteps.astype(np.int64).astype(np.float32) + sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("base_shift", 0.25), + self.scheduler.config.get("max_shift", 0.75), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu + ) + self._num_timesteps = len(timesteps) + + # Denoising loop + transformer_dtype = self.transformer.dtype + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + 
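# At each step the transformer is run on the conditional embeddings and, when classifier-free guidance is enabled, on the negative embeddings as well; the two predictions are combined and the scheduler takes one Euler step. +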
latent_model_input = latents.to(transformer_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=negative_prompt_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + image = self.vae.decode(latents, return_dict=False, generator=generator)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return CogView4PipelineOutput(images=image) diff --git a/src/diffusers/pipelines/cogview4/pipeline_output.py b/src/diffusers/pipelines/cogview4/pipeline_output.py new file mode 100644 index 000000000000..4efec1310845 --- /dev/null +++ b/src/diffusers/pipelines/cogview4/pipeline_output.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class CogView4PipelineOutput(BaseOutput): + """ + Output class for CogView4 pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 5f17f044cc69..e3bff7582cd9 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -78,6 +78,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): Whether to use exponential sigmas for step sizes in the noise schedule during sampling. use_beta_sigmas (`bool`, defaults to False): Whether to use beta sigmas for step sizes in the noise schedule during sampling.
+ time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". """ _compatibles = [] @@ -88,7 +90,7 @@ def __init__( self, num_train_timesteps: int = 1000, shift: float = 1.0, - use_dynamic_shifting=False, + use_dynamic_shifting: bool = False, base_shift: Optional[float] = 0.5, max_shift: Optional[float] = 1.15, base_image_seq_len: Optional[int] = 256, @@ -98,6 +100,7 @@ def __init__( use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, + time_shift_type: str = "exponential", ): if self.config.use_beta_sigmas and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -105,6 +108,9 @@ def __init__( raise ValueError( "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." ) + if time_shift_type not in {"exponential", "linear"}: + raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.") + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) @@ -211,7 +217,10 @@ def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps def time_shift(self, mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: r""" @@ -236,54 +245,94 @@ def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: def set_timesteps( self, - num_inference_steps: int = None, + num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None, sigmas: Optional[List[float]] = None, mu: Optional[float] = None, + timesteps: Optional[List[float]] = None, ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. + mu (`float`, *optional*): + Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep + shifting. + timesteps (`List[float]`, *optional*): + Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed + automatically. 
""" if self.config.use_dynamic_shifting and mu is None: - raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`") + + if sigmas is not None and timesteps is not None: + if len(sigmas) != len(timesteps): + raise ValueError("`sigmas` and `timesteps` should have the same length") + + if num_inference_steps is not None: + if (sigmas is not None and len(sigmas) != num_inference_steps) or ( + timesteps is not None and len(timesteps) != num_inference_steps + ): + raise ValueError( + "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" + ) + else: + num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps) - if sigmas is None: - timesteps = np.linspace( - self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps - ) + self.num_inference_steps = num_inference_steps + + # 1. Prepare default sigmas + is_timesteps_provided = timesteps is not None + if is_timesteps_provided: + timesteps = np.array(timesteps).astype(np.float32) + + if sigmas is None: + if timesteps is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) sigmas = timesteps / self.config.num_train_timesteps else: sigmas = np.array(sigmas).astype(np.float32) num_inference_steps = len(sigmas) - self.num_inference_steps = num_inference_steps + # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of + # "exponential" or "linear" type is applied if self.config.use_dynamic_shifting: sigmas = self.time_shift(mu, 1.0, sigmas) else: sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value if self.config.shift_terminal: sigmas = self.stretch_shift_to_terminal(sigmas) + # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - elif self.config.use_beta_sigmas: sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + # 5. Convert sigmas and timesteps to tensors and move to specified device sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) - timesteps = sigmas * self.config.num_train_timesteps + if not is_timesteps_provided: + timesteps = sigmas * self.config.num_train_timesteps + else: + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) + # 6. Append the terminal sigma value. + # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the + # `invert_sigmas` flag can be set to `True`. 
This case is only required in Mochi if self.config.invert_sigmas: sigmas = 1.0 - sigmas timesteps = sigmas * self.config.num_train_timesteps @@ -291,7 +340,7 @@ def set_timesteps( else: sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - self.timesteps = timesteps.to(device=device) + self.timesteps = timesteps self.sigmas = sigmas self._step_index = None self._begin_index = None @@ -474,5 +523,11 @@ def _convert_to_beta( ) return sigmas + def _time_shift_exponential(self, mu, sigma, t): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def _time_shift_linear(self, mu, sigma, t): + return mu / (mu + (1 / t - 1) ** sigma) + def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 57198d9409f4..9dd1e690742f 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -276,6 +276,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class CogView4Transformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ConsisIDTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 02bef4aba0a5..c853cf8faa55 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -362,6 +362,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class CogView4Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class ConsisIDPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_cogview4.py b/tests/models/transformers/test_models_transformer_cogview4.py new file mode 100644 index 000000000000..e311ce77ea50 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_cogview4.py @@ -0,0 +1,83 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
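+# Fast sanity tests for CogView4Transformer2DModel: dummy-input forward pass and gradient checkpointing.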
+ +import unittest + +import torch + +from diffusers import CogView4Transformer2DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class CogView4TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = CogView4Transformer2DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = 8 + width = 8 + embedding_dim = 8 + sequence_length = 8 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device) + target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device) + crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "original_size": original_size, + "target_size": target_size, + "crop_coords": crop_coords, + } + + @property + def input_shape(self): + return (4, 8, 8) + + @property + def output_shape(self): + return (4, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": 2, + "in_channels": 4, + "num_layers": 2, + "attention_head_dim": 4, + "num_attention_heads": 4, + "out_channels": 4, + "text_embed_dim": 8, + "time_embed_dim": 8, + "condition_dim": 4, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"CogView4Transformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/cogview4/__init__.py b/tests/pipelines/cogview4/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/cogview4/test_cogview4.py b/tests/pipelines/cogview4/test_cogview4.py new file mode 100644 index 000000000000..2a97a0799d76 --- /dev/null +++ b/tests/pipelines/cogview4/test_cogview4.py @@ -0,0 +1,234 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
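+# Fast tests for CogView4Pipeline assembled from tiny dummy components (CogView4 transformer, AutoencoderKL, GLM text encoder, flow-match Euler scheduler).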
+ +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, GlmConfig, GlmForCausalLM + +from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class CogView4PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = CogView4Pipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + supports_dduf = False + test_xformers_attention = False + test_layerwise_casting = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CogView4Transformer2DModel( + patch_size=2, + in_channels=4, + num_layers=2, + attention_head_dim=4, + num_attention_heads=4, + out_channels=4, + text_embed_dim=32, + time_embed_dim=8, + condition_dim=4, + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler( + base_shift=0.25, + max_shift=0.75, + base_image_seq_len=256, + use_dynamic_shifting=True, + time_shift_type="linear", + ) + + torch.manual_seed(0) + text_encoder_config = GlmConfig( + hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8 + ) + text_encoder = GlmForCausalLM(text_encoder_config) + # TODO(aryan): change this to THUDM/CogView4 once released + tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs)[0] + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 16, 16)) + expected_image = torch.randn(3, 16, 16) + max_diff = np.abs(generated_image - expected_image).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + 
has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + )
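For reviewers who want to exercise the new pipeline end to end, here is a minimal usage sketch. It assumes the released `THUDM/CogView4-6B` checkpoint and a CUDA device with enough memory for the bf16 weights; the values shown (50 steps, guidance scale 5.0, 1024x1024) match the defaults documented in the pipeline's `__call__`.

```python
import torch

from diffusers import CogView4Pipeline

# Load the full pipeline in bfloat16, the dtype CogView4 was trained with.
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Generate a single 1024x1024 image; guidance_scale > 1 enables classifier-free guidance.
image = pipe(
    prompt="A detailed painting of a fox reading a book under a lantern",
    guidance_scale=5.0,
    num_inference_steps=50,
    height=1024,
    width=1024,
).images[0]
image.save("cogview4_sample.png")
```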