Skip to content

Update Stable Cascade Conversion Scripts #7271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 159 additions & 156 deletions scripts/convert_stable_cascade.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
from contextlib import nullcontext

import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
Expand All @@ -18,23 +18,56 @@
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available


if is_accelerate_available():
from accelerate import init_empty_weights

# ---------------------------------------------------------------------------
# CLI arguments.
# NOTE(review): the scraped diff showed two `--model_path` definitions (the
# pre-PR line carried default="../StableCascade"); argparse raises
# ArgumentError on duplicate option strings, so only the post-PR definition
# is kept. The three `*_output_path` help strings were copy-pasted from
# `--save_org` ("Hub organization...") and are corrected to describe output
# paths. Indentation (stripped by the page scrape) is restored.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
parser.add_argument(
    "--prior_output_path", default="stable-cascade-prior", type=str, help="Output path for the prior pipeline"
)
parser.add_argument(
    "--decoder_output_path",
    type=str,
    default="stable-cascade-decoder",
    help="Output path for the decoder pipeline",
)
parser.add_argument(
    "--combined_output_path",
    type=str,
    default="stable-cascade-combined",
    help="Output path for the combined pipeline",
)
parser.add_argument("--save_combined", action="store_true")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()

# A run that converts neither stage is a no-op; a combined pipeline needs both.
if args.skip_stage_b and args.skip_stage_c:
    raise ValueError("At least one stage should be converted")
if (args.skip_stage_b or args.skip_stage_c) and args.save_combined:
    raise ValueError("Cannot skip stages when creating a combined pipeline")

model_path = args.model_path

# Conversion runs on CPU; --variant bf16 saves bfloat16 weights, else float32.
device = "cpu"
dtype = torch.bfloat16 if args.variant == "bf16" else torch.float32

# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
Expand All @@ -52,164 +85,134 @@
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# NOTE(review): scraped-diff artifact — this span is the PRE-PR ("removed")
# prior-conversion path; the PR replaces it with the `skip_stage_c`-guarded
# block further below that uses convert_stable_cascade_unet_single_file_to_diffusers.
# The page scrape stripped all indentation, so the lines are reproduced flat
# and byte-identical; only comments are added.
# Prior
if args.use_safetensors:
orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

# Remap original attention keys: split the fused qkv in_proj weight/bias into
# diffusers' to_q / to_k / to_v, and rename out_proj to to_out.0.
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]


# Instantiate the stage-C prior UNet on the meta device, then load the
# remapped weights into it.
with accelerate.init_empty_weights():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=2048,
block_out_channels=[2048, 2048],
num_attention_heads=[32, 32],
down_num_layers_per_block=[8, 24],
up_num_layers_per_block=[24, 8],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
load_model_dict_into_meta(prior_model, state_dict)

# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
# Meta-device init context when accelerate is available; no-op otherwise.
ctx = init_empty_weights if is_accelerate_available() else nullcontext

# NOTE(review): pre-PR save path (removed by the PR) — the post-PR code saves
# via args.prior_output_path with a dtype cast instead of this hard-coded
# f"{args.save_org}/StableCascade-prior" location. Lines reproduced
# byte-identical (scrape stripped indentation); only comments added.
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)

# NOTE(review): pre-PR decoder key-remap loop (removed by the PR). Its
# trailing `else` branch was interleaved with new code further below in the
# scraped diff, so the `elif` chain here appears truncated. Lines reproduced
# byte-identical; only comments added.
# Decoder
if args.use_safetensors:
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

# Same fused-qkv split / out_proj rename as the prior, plus the clip_mapper
# rename specific to the stage-B decoder.
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
# ---------------------------------------------------------------------------
# Stage C (prior) conversion.
# NOTE(review): the scraped diff interleaved removed pre-PR lines into this
# span (an orphaned `else` body, the old decoder config, and a stray
# load_model_dict_into_meta call); the body below reconstructs the post-PR
# code with indentation restored.
# ---------------------------------------------------------------------------
if not args.skip_stage_c:
    # Prior: load the original stage-C checkpoint from disk.
    if args.use_safetensors:
        prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
    else:
        prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

    # Remap checkpoint keys to diffusers naming via the shared single-file helper.
    prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)

    # Build the prior UNet (on the meta device when accelerate is available).
    with ctx():
        prior_model = StableCascadeUNet(
            in_channels=16,
            out_channels=16,
            timestep_ratio_embedding_dim=64,
            patch_size=1,
            conditioning_dim=2048,
            block_out_channels=[2048, 2048],
            num_attention_heads=[32, 32],
            down_num_layers_per_block=[8, 24],
            up_num_layers_per_block=[24, 8],
            down_blocks_repeat_mappers=[1, 1],
            up_blocks_repeat_mappers=[1, 1],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_in_channels=1280,
            clip_text_pooled_in_channels=1280,
            clip_image_in_channels=768,
            clip_seq=4,
            kernel_size=3,
            dropout=[0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca", "crp"],
            switch_level=[False],
        )
    # Materialize the weights: meta-device load when accelerate is present,
    # plain load_state_dict otherwise.
    if is_accelerate_available():
        load_model_dict_into_meta(prior_model, prior_state_dict)
    else:
        prior_model.load_state_dict(prior_state_dict)

    # Prior pipeline
    prior_pipeline = StableCascadePriorPipeline(
        prior=prior_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
    )
    prior_pipeline.to(dtype).save_pretrained(
        args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# ---------------------------------------------------------------------------
# Stage B (decoder) conversion.
# NOTE(review): the scraped diff interleaved removed pre-PR lines into this
# span (the old hard-coded save_pretrained call and the old combined-pipeline
# block); the body below reconstructs the post-PR code with indentation
# restored.
# ---------------------------------------------------------------------------
if not args.skip_stage_b:
    # Decoder: load the original stage-B checkpoint from disk.
    if args.use_safetensors:
        decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
    else:
        decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

    # Remap checkpoint keys to diffusers naming via the shared single-file helper.
    decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)

    # Build the decoder UNet (on the meta device when accelerate is available).
    with ctx():
        decoder = StableCascadeUNet(
            in_channels=4,
            out_channels=4,
            timestep_ratio_embedding_dim=64,
            patch_size=2,
            conditioning_dim=1280,
            block_out_channels=[320, 640, 1280, 1280],
            down_num_layers_per_block=[2, 6, 28, 6],
            up_num_layers_per_block=[6, 28, 6, 2],
            down_blocks_repeat_mappers=[1, 1, 1, 1],
            up_blocks_repeat_mappers=[3, 3, 2, 2],
            num_attention_heads=[0, 0, 20, 20],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_pooled_in_channels=1280,
            clip_seq=4,
            effnet_in_channels=16,
            pixel_mapper_in_channels=3,
            kernel_size=3,
            dropout=[0, 0, 0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca"],
        )

    # Materialize the weights: meta-device load when accelerate is present,
    # plain load_state_dict otherwise.
    if is_accelerate_available():
        load_model_dict_into_meta(decoder, decoder_state_dict)
    else:
        decoder.load_state_dict(decoder_state_dict)

    # VQGAN from Wuerstchen-V2
    vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

    # Decoder pipeline
    decoder_pipeline = StableCascadeDecoderPipeline(
        decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
    )
    decoder_pipeline.to(dtype).save_pretrained(
        args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )

# ---------------------------------------------------------------------------
# Combined pipeline (prior + decoder in one artifact).
# Only reachable when neither stage was skipped (validated at argument
# parsing), so `prior_model` and `decoder` both exist here.
# NOTE(review): this span was coherent in the scraped diff but its
# indentation was stripped; restored below.
# ---------------------------------------------------------------------------
if args.save_combined:
    # Stable Cascade combined pipeline
    stable_cascade_pipeline = StableCascadeCombinedPipeline(
        # Decoder
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        decoder=decoder,
        scheduler=scheduler,
        vqgan=vqmodel,
        # Prior
        prior_text_encoder=text_encoder,
        prior_tokenizer=tokenizer,
        prior_prior=prior_model,
        prior_scheduler=scheduler,
        prior_image_encoder=image_encoder,
        prior_feature_extractor=feature_extractor,
    )
    stable_cascade_pipeline.to(dtype).save_pretrained(
        args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
Loading