[Pipeline] Wuerstchen v3 aka Stable Cascade pipeline #6487
Merged
Changes from 9 commits (123 commits in total)
- 6185da3 initial diffNext v3 (kashif)
- 6fd8639 move to v3 folder (kashif)
- 86e2bcd imports (kashif)
- e77db10 dry up the unets (kashif)
- 644dc5d Merge branch 'main' into wuerstchen-v3 (kashif)
- 1380b95 no switch_level (kashif)
- 2bca122 fix init (kashif)
- d4d0bc1 add switch_level tp config (kashif)
- 0db9e4d Fixed some things (dome272)
- 87e5577 Added pooled text embeddings (dome272)
- 38f9f35 Initial work on adding image encoder (dome272)
- dc3f47e changes from @dome272 (kashif)
- 5c6635f Stuff for the image encoder processing and variable naming in decoder (dome272)
- 3d41b2a fix arg name (kashif)
- 2012e71 inference fixes (dome272)
- add164a inference fixes (dome272)
- f6035c6 Merge branch 'main' into wuerstchen-v3 (kashif)
- edbd76b default TimestepBlock without conds (kashif)
- c5326fa c_skip=0 by default (kashif)
- 228f98c fix bfloat16 to cpu (kashif)
- b1e6db3 use config (kashif)
- 0fb4bf8 undo temp change (kashif)
- 834baba fix gen_c_embeddings args (kashif)
- fc361d2 change text encoding (dome272)
- 7632707 text encoding (dome272)
- bef887a undo print (kashif)
- 0816469 Merge branch 'main' into wuerstchen-v3 (kashif)
- b1413d5 undo .gitignore change (kashif)
- ae5967b Allow WuerstchenV3PriorPipeline to use the base DDPM & DDIM schedulers (pabloppp)
- 979ea12 use WuerstchenV3Unet in both pipelines (kashif)
- dc24bb4 Merge branch 'main' into wuerstchen-v3 (kashif)
- 966cdbc fix imports (kashif)
- d9a71df initial failing tests (kashif)
- af02b68 cleanup (kashif)
- e962671 use scheduler.timesterps (kashif)
- a1ecef2 some fixes to the tests, still not fully working (pabloppp)
- 331d0d3 fix tests (kashif)
- 7452985 fix prior tests (kashif)
- c0bb4ca add dropout to the model_kwargs (kashif)
- e01bc49 more tests passing (kashif)
- 17fed8c update expected_slice (kashif)
- 733ec02 initial rename (kashif)
- 021c3e2 rename tests (kashif)
- b2c615f rename class names (kashif)
- 3d5328e make fix-copies (kashif)
- 33a1af8 initial docs (kashif)
- a7040a2 autodocs (kashif)
- 8882633 typos (kashif)
- e63a312 fix arg docs (kashif)
- cdeb5da add text_encoder info (kashif)
- 72b87e7 combined pipeline has optional image arg (kashif)
- d929cdf Merge branch 'main' into wuerstchen-v3 (sayakpaul)
- c883cb2 fix documentation (sayakpaul)
- 66a17e1 Merge branch 'main' into wuerstchen-v3 (kashif)
- a3dc213 Merge branch 'main' into wuerstchen-v3 (kashif)
- 33b70f4 Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade… (kashif)
- cc10c29 Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade… (kashif)
- 6f5ed3d Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade… (kashif)
- bf3a972 Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade… (kashif)
- 5634ef3 Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py (kashif)
- 3cf4c1b Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade… (kashif)
- b5e2ca9 use self.config (kashif)
- 9b525fd Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade… (kashif)
- 60efc49 c_in -> in_channels (kashif)
- cbd0775 removed kwargs from unet's forward (kashif)
- c1f72e3 Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py (kashif)
- 7cb3838 Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py (kashif)
- b3a80f7 remove older callback api (kashif)
- 519805f removed kwargs and fixed decoder guidance > 1 (kashif)
- 7698bf6 decoder takes emeds (kashif)
- 88633a9 check and use image_embeds (kashif)
- d68207b fixed all but one decoder test (kashif)
- 143df09 fix decoder tests (kashif)
- 2483df2 Merge branch 'main' into wuerstchen-v3 (kashif)
- f4a788b Merge branch 'main' into wuerstchen-v3 (kashif)
- 169db20 update callback api (kashif)
- 3cb0ec1 fix some more combined tests (kashif)
- 84f4f3d push combined pipeline (kashif)
- 4f69a51 initial docs (kashif)
- 7dcbdc6 fix doc_string (kashif)
- 4f5dffb update combined api (kashif)
- 4753b99 no test_callback_inputs test for combined pipeline (kashif)
- adec75f add optional components (kashif)
- 2e877d2 fix ordering of components (kashif)
- e956f3e fix combined tests (kashif)
- f18ff23 update convert script (kashif)
- 3ff5120 Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade… (kashif)
- 979fed0 Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade… (kashif)
- 72cf605 Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade… (kashif)
- b2e0f06 fix imports (kashif)
- 4c33b8a move effnet out of deniosing loop (kashif)
- 9785210 prompt_embeds_pooled only when doing guidance (kashif)
- 25ecc81 Fix repeat shape (99991)
- 1b171b6 Merge pull request #2 from 99991/wuerstchen-v3 (kashif)
- 4914e04 move StableCascadeUnet to models/unets/ (kashif)
- 8c2e479 more descriptive names (kashif)
- 2d1f438 Merge branch 'main' into wuerstchen-v3 (kashif)
- 871387e converted when numpy() (kashif)
- 85fb15c StableCascadePriorPipelineOutput docs (kashif)
- 72249af rename StableCascadeUNet (kashif)
- 6767b29 add slow tests (kashif)
- cb7f47c fix slow tests (kashif)
- 7ff8828 Merge branch 'main' into wuerstchen-v3 (kashif)
- 748ab08 update (DN6)
- 3ad7516 update (DN6)
- e7434ff updated model_path (kashif)
- ac716ab add args for weights (kashif)
- 13e9812 set push_to_hub to false (kashif)
- b6d3b6f update (DN6)
- a07623f update (DN6)
- a487d16 Merge branch 'wuerstchen-v3' of https://github.com/kashif/diffusers i… (DN6)
- c6a5537 update (DN6)
- e505de1 update (DN6)
- a2a5060 update (DN6)
- 3326dee update (DN6)
- 11eac5f update (DN6)
- 2c226cd update (DN6)
- d3e8cef update (DN6)
- df5ed03 update (DN6)
- 8e74e09 update (DN6)
- c1cd769 update (DN6)
- ceedcc4 update (DN6)
- 8dd88f0 update (DN6)
Conversion script (new file, 126 lines):
# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/
import os

import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPVisionModelWithProjection

# from vqgan import VQModel

from diffusers import (
    DDPMWuerstchenScheduler,
    WuerstchenV3CombinedPipeline,
    WuerstchenV3DecoderPipeline,
    WuerstchenV3PriorPipeline,
)
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.pipelines.wuerstchen3 import WuerstchenV3DiffNeXt, WuerstchenV3Prior


model_path = "../Wuerstchen/"
device = "cpu"

# paella_vqmodel = VQModel()
# state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
# paella_vqmodel.load_state_dict(state_dict)

# state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
# state_dict.pop("vquantizer.codebook.weight")
# vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
# vqmodel.load_state_dict(state_dict)

# # Clip Text encoder and tokenizer
# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
# tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# # Generator
# clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").to("cpu")

# Decoder: split fused attention projections and rename keys to diffusers conventions
orig_state_dict = torch.load(os.path.join(model_path, "base_120k.pt"), map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    # rename clip_mapper to clip_txt_pooled_mapper
    elif key.endswith("clip_mapper.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
    elif key.endswith("clip_mapper.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]
decoder = WuerstchenV3DiffNeXt().to(device)
decoder.load_state_dict(state_dict)


# Prior: same attention-key conversion as the decoder above
orig_state_dict = torch.load(os.path.join(model_path, "v1.pt"), map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]
prior_model = WuerstchenV3Prior().to(device)
prior_model.load_state_dict(state_dict)

import pdb

pdb.set_trace()

# # scheduler
# scheduler = DDPMWuerstchenScheduler()
#
# # Prior pipeline
# prior_pipeline = WuerstchenPriorPipeline(
#     prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
# )
#
# prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")
#
# decoder_pipeline = WuerstchenDecoderPipeline(
#     text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=decoder, scheduler=scheduler
# )
# decoder_pipeline.save_pretrained("warp-ai/wuerstchen")
#
# # Wuerstchen pipeline
# wuerstchen_pipeline = WuerstchenCombinedPipeline(
#     # Decoder
#     text_encoder=gen_text_encoder,
#     tokenizer=gen_tokenizer,
#     decoder=decoder,
#     scheduler=scheduler,
#     vqgan=vqmodel,
#     # Prior
#     prior_tokenizer=tokenizer,
#     prior_text_encoder=text_encoder,
#     prior=prior_model,
#     prior_scheduler=scheduler,
# )
# wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline")
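The key-renaming loops above hinge on one detail: the original checkpoints store the attention q, k, and v projections fused into a single `in_proj_weight`, which `chunk(3, 0)` splits into three equal parts along dim 0. A minimal toy sketch of that step (hypothetical key name and tiny embedding dim, chosen only for illustration):

```python
import torch

# Hypothetical toy checkpoint: a fused attention projection for embedding
# dim 4 stores q, k, v stacked along dim 0 as one (3*4, 4) matrix.
embed_dim = 4
orig_state_dict = {
    "blocks.0.attn.in_proj_weight": torch.arange(
        3 * embed_dim * embed_dim, dtype=torch.float32
    ).reshape(3 * embed_dim, embed_dim)
}

state_dict = {}
for key, value in orig_state_dict.items():
    if key.endswith("in_proj_weight"):
        # chunk(3, 0) splits the stacked (3*d, d) matrix into equal q, k, v parts
        q, k, v = value.chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = q
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = k
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = v
    else:
        state_dict[key] = value

# the fused key is replaced by three per-projection (d, d) keys
print(sorted(state_dict))
```

Each chunk is a view of rows `[i*d, (i+1)*d)` of the fused matrix, so `to_k.weight` ends up holding rows `d` through `2d-1`, matching the layout PyTorch's fused `MultiheadAttention` uses.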
New module init (54 lines), wiring up lazy imports for the wuerstchen3 pipelines:
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modeling_wuerstchen3_diffnext"] = ["WuerstchenV3DiffNeXt"]
    _import_structure["modeling_wuerstchen3_prior"] = ["WuerstchenV3Prior"]
    _import_structure["pipeline_wuerstchen3"] = ["WuerstchenV3DecoderPipeline"]
    _import_structure["pipeline_wuerstchen3_combined"] = ["WuerstchenV3CombinedPipeline"]
    _import_structure["pipeline_wuerstchen3_prior"] = ["WuerstchenV3PriorPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modeling_wuerstchen3_diffnext import WuerstchenV3DiffNeXt
        from .modeling_wuerstchen3_prior import WuerstchenV3Prior
        from .pipeline_wuerstchen3 import WuerstchenV3DecoderPipeline
        from .pipeline_wuerstchen3_combined import WuerstchenV3CombinedPipeline
        from .pipeline_wuerstchen3_prior import WuerstchenV3PriorPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
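The `_import_structure` dict maps each submodule to the names it exports, and `_LazyModule` defers the actual imports until an attribute is first accessed. A minimal, self-contained sketch of that idea (this `LazyModule` class is a hypothetical simplification, not diffusers' actual `_LazyModule`, and the stdlib `json` module stands in for a real submodule):

```python
import importlib
import sys
import types

class LazyModule(types.ModuleType):
    """Simplified lazy module: imports a submodule only on first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # invert {module: [attrs]} into {attr: module} for direct lookup
        self._attr_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, name):
        # only called when normal attribute lookup fails, i.e. before first access
        if name not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        module = importlib.import_module(self._attr_to_module[name])
        value = getattr(module, name)
        setattr(self, name, value)  # cache: later accesses bypass __getattr__
        return value

# register under a made-up name; "json" plays the role of a heavy submodule
lazy = LazyModule("lazy_demo", {"json": ["dumps", "loads"]})
sys.modules["lazy_demo"] = lazy
print(lazy.dumps({"a": 1}))
```

The payoff is that `import diffusers.pipelines.wuerstchen3` stays cheap: torch- and transformers-backed classes are only materialized when someone actually touches them, which is why the `TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT` branch exists for tools that need the real symbols eagerly.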